blas: imported blas as a subtree

This commit is contained in:
Brendan Tracey
2017-05-23 00:02:42 -06:00
99 changed files with 30343 additions and 0 deletions

46
blas/.travis.yml Normal file
View File

@@ -0,0 +1,46 @@
sudo: required
language: go
# Versions of go that are explicitly supported by gonum.
go:
- 1.5.4
- 1.6.3
- 1.7.3
os:
- linux
- osx
# BLAS back ends to test. The value of BLAS_LIB selects which
# .travis/<os>/<BLAS_LIB>/ install.sh and test.sh scripts are sourced below.
env:
matrix:
- BLAS_LIB=OpenBLAS
- BLAS_LIB=gonum
- BLAS_LIB=Accelerate
- BLAS_LIB=ATLAS
# OS/back-end combinations that are not built. With these exclusions the
# remaining jobs are linux with OpenBLAS or gonum, and osx with gonum or
# Accelerate.
matrix:
exclude:
- os: linux
env: BLAS_LIB=Accelerate
- os: linux
env: BLAS_LIB=ATLAS
- os: osx
env: BLAS_LIB=ATLAS
- os: osx
env: BLAS_LIB=OpenBLAS
# Cache used to persist the compiled BLAS library between CI calls.
cache:
directories:
- .travis/OpenBLAS.cache
# Install the appropriate blas library (if any) and associated gonum software.
install:
- source ${TRAVIS_BUILD_DIR}/.travis/$TRAVIS_OS_NAME/$BLAS_LIB/install.sh
script:
- source ${TRAVIS_BUILD_DIR}/.travis/$TRAVIS_OS_NAME/$BLAS_LIB/test.sh
# Fail the build if gofmt reports any formatting difference.
- test -z "$(gofmt -d .)"
# This is run last since it alters the tree.
- ${TRAVIS_BUILD_DIR}/.travis/check-generate.sh

8
blas/.travis/check-generate.sh Executable file
View File

@@ -0,0 +1,8 @@
#!/bin/bash
# Re-runs code generation for the native and cgo packages and fails when the
# working tree differs from the index afterwards, i.e. when `git diff`
# prints anything.
set -ex

go generate github.com/gonum/blas/native
go generate github.com/gonum/blas/cgo

# An empty diff means the generators reproduced the files already present.
[ -z "$(git diff)" ] || exit 1

View File

@@ -0,0 +1,21 @@
# Install step for building gonum/blas against ATLAS on an apt-based system.
# Abort on the first failing command (-e) and echo every command (-x); both
# options are switched off again at the bottom of this file.
set -ex
# fetch and install ATLAS libs
sudo apt-get update -qq && sudo apt-get install -qq libatlas-base-dev
# fetch and install gonum/blas against ATLAS; cgo reads the linker flags
# from the exported CGO_LDFLAGS variable
export CGO_LDFLAGS="-L/usr/lib -lblas"
go get github.com/gonum/blas
# run the OS common installation script
source ${TRAVIS_BUILD_DIR}/.travis/$TRAVIS_OS_NAME/install.sh
# Travis compiles the commands it is given into one script and executes
# that in bash, so shell options set here live on past this file. Some of
# Travis's own commands are supposed to exit with a non-zero status and
# then continue, which set -e would break, and set -x makes the log files
# extremely verbose and difficult to understand. Switch both off again.
#
# see travis-ci/travis-ci#5120
set +ex

View File

@@ -0,0 +1,17 @@
# Test step: runs the Go tests for the packages below the current directory.
# Abort on the first failing command (-e) and echo every command (-x); both
# options are switched off again at the bottom of this file.
set -ex
# Print the Go environment to the build log.
go env
# Download the dependencies of all packages, including test dependencies
# (-t), without installing them (-d).
go get -d -t -v ./...
# Run every test, forcing all packages to be rebuilt (-a).
go test -a -v ./...
# Run the tests a second time with the noasm build tag set.
go test -a -tags noasm -v ./...
# Run the coverage script only when Travis reports that secure environment
# variables are available to this build.
if [[ $TRAVIS_SECURE_ENV_VARS = "true" ]]; then bash -c "$GOPATH/src/github.com/$TRAVIS_REPO_SLUG/.travis/test-coverage.sh"; fi
# Travis compiles the commands it is given into one script and executes
# that in bash, so shell options set here live on past this file. Some of
# Travis's own commands are supposed to exit with a non-zero status and
# then continue, which set -e would break, and set -x makes the log files
# extremely verbose and difficult to understand. Switch both off again.
#
# see travis-ci/travis-ci#5120
set +ex

View File

@@ -0,0 +1,71 @@
# Install step for building gonum/blas against OpenBLAS on linux.
#
# OpenBLAS and the netlib CBLAS wrapper are built from source into a cache
# directory that persists between CI runs. The cache records the OpenBLAS
# commit it was built from and is rebuilt whenever the upstream HEAD differs.
#
# Abort on the first failing command (-e) and echo every command (-x); both
# options are switched off again at the bottom of this file.
set -ex

CACHE_DIR="${TRAVIS_BUILD_DIR}/.travis/${BLAS_LIB}.cache"

# fetch fortran to build OpenBLAS
sudo apt-get update -qq && sudo apt-get install -qq gfortran

# check if cache exists
if [ -e "${CACHE_DIR}/last_commit_id" ]; then
    echo "Cache $CACHE_DIR hit"
    # determine current OpenBLAS master commit id and compare
    # with commit id in cache directory.
    # The repository is reached over https: GitHub no longer serves the
    # unauthenticated git:// protocol.
    LAST_COMMIT="$(git ls-remote https://github.com/xianyi/OpenBLAS HEAD | grep -o '^\S*')"
    CACHED_COMMIT="$(cat "${CACHE_DIR}/last_commit_id")"
    if [ "$LAST_COMMIT" != "$CACHED_COMMIT" ]; then
        echo "Cache Directory $CACHE_DIR has stale commit"
        # if commit is different, delete the cache
        rm -rf "${CACHE_DIR}"
    fi
fi

if [ ! -e "${CACHE_DIR}/last_commit_id" ]; then
    # Clear cache.
    rm -rf "${CACHE_DIR}"

    # cache generation
    echo "Building cache at $CACHE_DIR"
    mkdir -p "${CACHE_DIR}"

    sudo git clone --depth=1 https://github.com/xianyi/OpenBLAS

    # Build OpenBLAS (build output is discarded) and install it into the cache.
    pushd OpenBLAS
    sudo make FC=gfortran &> /dev/null && sudo make PREFIX="${CACHE_DIR}" install
    popd

    # Build the CBLAS wrapper against the freshly built OpenBLAS archive.
    curl http://www.netlib.org/blas/blast-forum/cblas.tgz | tar -zx

    pushd CBLAS
    sudo mv Makefile.LINUX Makefile.in
    sudo BLLIB="${CACHE_DIR}/lib/libopenblas.a" make alllib
    sudo mv lib/cblas_LINUX.a "${CACHE_DIR}/lib/libcblas.a"
    popd

    # Record commit id used to generate cache. Writing the command's output
    # directly (rather than through echo) lets set -e catch a failure.
    pushd OpenBLAS
    git rev-parse HEAD > "${CACHE_DIR}/last_commit_id"
    popd
fi

# copy the cache files into /usr
sudo cp -r "${CACHE_DIR}"/* /usr/

# install gonum/blas against OpenBLAS
export CGO_LDFLAGS="-L/usr/lib -lopenblas"
go get github.com/gonum/blas
pushd cgo
go install -v -x
popd

# run the OS common installation script
source "${TRAVIS_BUILD_DIR}/.travis/${TRAVIS_OS_NAME}/install.sh"

# Travis compiles the commands it is given into one script and executes
# that in bash, so shell options set here live on past this file. Some of
# Travis's own commands are supposed to exit with a non-zero status and
# then continue, which set -e would break, and set -x makes the log files
# extremely verbose and difficult to understand. Switch both off again.
#
# see travis-ci/travis-ci#5120
set +ex

View File

@@ -0,0 +1,17 @@
# Test step: runs the Go tests for the packages below the current directory.
# Abort on the first failing command (-e) and echo every command (-x); both
# options are switched off again at the bottom of this file.
set -ex
# Print the Go environment to the build log.
go env
# Download the dependencies of all packages, including test dependencies
# (-t), without installing them (-d).
go get -d -t -v ./...
# Run every test, forcing all packages to be rebuilt (-a).
go test -a -v ./...
# Run the tests a second time with the noasm build tag set.
go test -a -tags noasm -v ./...
# Run the coverage script only when Travis reports that secure environment
# variables are available to this build.
if [[ $TRAVIS_SECURE_ENV_VARS = "true" ]]; then bash -c "$GOPATH/src/github.com/$TRAVIS_REPO_SLUG/.travis/test-coverage.sh"; fi
# Travis compiles the commands it is given into one script and executes
# that in bash, so shell options set here live on past this file. Some of
# Travis's own commands are supposed to exit with a non-zero status and
# then continue, which set -e would break, and set -x makes the log files
# extremely verbose and difficult to understand. Switch both off again.
#
# see travis-ci/travis-ci#5120
set +ex

View File

@@ -0,0 +1,18 @@
# Install step for the pure-Go configuration: no external BLAS library is
# installed, only the common tooling.
# Abort on the first failing command (-e) and echo every command (-x); both
# options are switched off again at the bottom of this file.
set -ex
# run the OS common installation script
source ${TRAVIS_BUILD_DIR}/.travis/$TRAVIS_OS_NAME/install.sh
# change to native directory so we don't test code that depends on an external
# blas library. When this file is sourced the directory change stays in
# effect for the commands that run after it.
cd native
# Travis compiles the commands it is given into one script and executes
# that in bash, so shell options set here live on past this file. Some of
# Travis's own commands are supposed to exit with a non-zero status and
# then continue, which set -e would break, and set -x makes the log files
# extremely verbose and difficult to understand. Switch both off again.
#
# see travis-ci/travis-ci#5120
set +ex

View File

@@ -0,0 +1,17 @@
# Test step: runs the Go tests for the packages below the current directory.
# Abort on the first failing command (-e) and echo every command (-x); both
# options are switched off again at the bottom of this file.
set -ex
# Print the Go environment to the build log.
go env
# Download the dependencies of all packages, including test dependencies
# (-t), without installing them (-d).
go get -d -t -v ./...
# Run every test, forcing all packages to be rebuilt (-a).
go test -a -v ./...
# Run the tests a second time with the noasm build tag set.
go test -a -tags noasm -v ./...
# Run the coverage script only when Travis reports that secure environment
# variables are available to this build.
if [[ $TRAVIS_SECURE_ENV_VARS = "true" ]]; then bash -c "$GOPATH/src/github.com/$TRAVIS_REPO_SLUG/.travis/test-coverage.sh"; fi
# Travis compiles the commands it is given into one script and executes
# that in bash, so shell options set here live on past this file. Some of
# Travis's own commands are supposed to exit with a non-zero status and
# then continue, which set -e would break, and set -x makes the log files
# extremely verbose and difficult to understand. Switch both off again.
#
# see travis-ci/travis-ci#5120
set +ex

View File

@@ -0,0 +1,21 @@
# Abort on the first failing command (-e) and echo every command (-x); both
# options are switched off again at the bottom of this file.
set -ex
# This script contains common installation commands for linux. It should be run
# prior to more specific installation commands for a particular blas library.
# Coverage tooling: the cover tool and the goveralls uploader.
go get golang.org/x/tools/cmd/cover
go get github.com/mattn/goveralls
go get github.com/gonum/floats
# Repositories for code generation.
go get github.com/gonum/internal/binding
go get github.com/cznic/cc
# Travis compiles the commands it is given into one script and executes
# that in bash, so shell options set here live on past this file. Some of
# Travis's own commands are supposed to exit with a non-zero status and
# then continue, which set -e would break, and set -x makes the log files
# extremely verbose and difficult to understand. Switch both off again.
#
# see travis-ci/travis-ci#5120
set +ex

View File

@@ -0,0 +1,20 @@
# Install step for building gonum/blas against the Accelerate framework.
# Abort on the first failing command (-e) and echo every command (-x); both
# options are switched off again at the bottom of this file.
set -ex
# Link cgo code against the Accelerate framework.
export CGO_LDFLAGS="-framework Accelerate"
go get github.com/gonum/blas
# Build and install the cgo package verbosely; pushd/popd restore the
# previous working directory afterwards.
pushd cgo
go install -v -x
popd
# run the OS common installation script
source ${TRAVIS_BUILD_DIR}/.travis/$TRAVIS_OS_NAME/install.sh
# Travis compiles the commands it is given into one script and executes
# that in bash, so shell options set here live on past this file. Some of
# Travis's own commands are supposed to exit with a non-zero status and
# then continue, which set -e would break, and set -x makes the log files
# extremely verbose and difficult to understand. Switch both off again.
#
# see travis-ci/travis-ci#5120
set +ex

View File

@@ -0,0 +1,17 @@
# Test step: runs the Go tests for the packages below the current directory.
# Abort on the first failing command (-e) and echo every command (-x); both
# options are switched off again at the bottom of this file.
set -ex
# Print the Go environment to the build log.
go env
# Download the dependencies of all packages, including test dependencies
# (-t), without installing them (-d).
go get -d -t -v ./...
# Run every test, forcing all packages to be rebuilt (-a).
go test -a -v ./...
# Run the tests a second time with the noasm build tag set.
go test -a -tags noasm -v ./...
# Run the coverage script only when Travis reports that secure environment
# variables are available to this build.
if [[ $TRAVIS_SECURE_ENV_VARS = "true" ]]; then bash -c "$GOPATH/src/github.com/$TRAVIS_REPO_SLUG/.travis/test-coverage.sh"; fi
# Travis compiles the commands it is given into one script and executes
# that in bash, so shell options set here live on past this file. Some of
# Travis's own commands are supposed to exit with a non-zero status and
# then continue, which set -e would break, and set -x makes the log files
# extremely verbose and difficult to understand. Switch both off again.
#
# see travis-ci/travis-ci#5120
set +ex

View File

@@ -0,0 +1,24 @@
# Install step for building gonum/blas against a Homebrew-provided OpenBLAS.
# Abort on the first failing command (-e) and echo every command (-x); both
# options are switched off again at the bottom of this file.
set -ex
# fetch and install OpenBLAS using homebrew
# NOTE(review): the homebrew/science tap may no longer exist; confirm that
# this formula name still resolves before re-enabling this configuration.
brew install homebrew/science/openblas
# fetch and install gonum/blas against OpenBLAS
export CGO_LDFLAGS="-L/usr/local/opt/openblas/lib -lopenblas"
go get github.com/gonum/blas
# Build and install the cgo package verbosely; pushd/popd restore the
# previous working directory afterwards.
pushd cgo
go install -v -x
popd
# run the OS common installation script
source ${TRAVIS_BUILD_DIR}/.travis/$TRAVIS_OS_NAME/install.sh
# Travis compiles the commands it is given into one script and executes
# that in bash, so shell options set here live on past this file. Some of
# Travis's own commands are supposed to exit with a non-zero status and
# then continue, which set -e would break, and set -x makes the log files
# extremely verbose and difficult to understand. Switch both off again.
#
# see travis-ci/travis-ci#5120
set +ex

View File

@@ -0,0 +1,17 @@
# Test step: runs the Go tests for the packages below the current directory.
# Abort on the first failing command (-e) and echo every command (-x); both
# options are switched off again at the bottom of this file.
set -ex
# Print the Go environment to the build log.
go env
# Download the dependencies of all packages, including test dependencies
# (-t), without installing them (-d).
go get -d -t -v ./...
# Run every test, forcing all packages to be rebuilt (-a).
go test -a -v ./...
# Run the tests a second time with the noasm build tag set.
go test -a -tags noasm -v ./...
# Run the coverage script only when Travis reports that secure environment
# variables are available to this build.
if [[ $TRAVIS_SECURE_ENV_VARS = "true" ]]; then bash -c "$GOPATH/src/github.com/$TRAVIS_REPO_SLUG/.travis/test-coverage.sh"; fi
# Travis compiles the commands it is given into one script and executes
# that in bash, so shell options set here live on past this file. Some of
# Travis's own commands are supposed to exit with a non-zero status and
# then continue, which set -e would break, and set -x makes the log files
# extremely verbose and difficult to understand. Switch both off again.
#
# see travis-ci/travis-ci#5120
set +ex

View File

@@ -0,0 +1,18 @@
# Install step for the pure-Go configuration: no external BLAS library is
# installed, only the common tooling.
# Abort on the first failing command (-e) and echo every command (-x); both
# options are switched off again at the bottom of this file.
set -ex
# run the OS common installation script
source ${TRAVIS_BUILD_DIR}/.travis/$TRAVIS_OS_NAME/install.sh
# change to native directory so we don't test code that depends on an external
# blas library. When this file is sourced the directory change stays in
# effect for the commands that run after it.
cd native
# Travis compiles the commands it is given into one script and executes
# that in bash, so shell options set here live on past this file. Some of
# Travis's own commands are supposed to exit with a non-zero status and
# then continue, which set -e would break, and set -x makes the log files
# extremely verbose and difficult to understand. Switch both off again.
#
# see travis-ci/travis-ci#5120
set +ex

View File

@@ -0,0 +1,17 @@
# Test step: runs the Go tests for the packages below the current directory.
# Abort on the first failing command (-e) and echo every command (-x); both
# options are switched off again at the bottom of this file.
set -ex
# Print the Go environment to the build log.
go env
# Download the dependencies of all packages, including test dependencies
# (-t), without installing them (-d).
go get -d -t -v ./...
# Run every test, forcing all packages to be rebuilt (-a).
go test -a -v ./...
# Run the tests a second time with the noasm build tag set.
go test -a -tags noasm -v ./...
# Run the coverage script only when Travis reports that secure environment
# variables are available to this build.
if [[ $TRAVIS_SECURE_ENV_VARS = "true" ]]; then bash -c "$GOPATH/src/github.com/$TRAVIS_REPO_SLUG/.travis/test-coverage.sh"; fi
# Travis compiles the commands it is given into one script and executes
# that in bash, so shell options set here live on past this file. Some of
# Travis's own commands are supposed to exit with a non-zero status and
# then continue, which set -e would break, and set -x makes the log files
# extremely verbose and difficult to understand. Switch both off again.
#
# see travis-ci/travis-ci#5120
set +ex

View File

@@ -0,0 +1,23 @@
# Abort on the first failing command (-e) and echo every command (-x); both
# options are switched off again at the bottom of this file.
set -ex
# NOTE(review): the shebang below is not on the first line of the file, so
# the shell treats it as an ordinary comment.
#!/bin/bash
# This script contains common installation commands for osx. It should be run
# prior to more specific installation commands for a particular blas library.
# Coverage tooling: the cover tool and the goveralls uploader.
go get golang.org/x/tools/cmd/cover
go get github.com/mattn/goveralls
go get github.com/gonum/floats
# Repositories for code generation.
go get github.com/gonum/internal/binding
go get github.com/cznic/cc
# Travis compiles the commands it is given into one script and executes
# that in bash, so shell options set here live on past this file. Some of
# Travis's own commands are supposed to exit with a non-zero status and
# then continue, which set -e would break, and set -x makes the log files
# extremely verbose and difficult to understand. Switch both off again.
#
# see travis-ci/travis-ci#5120
set +ex

35
blas/.travis/test-coverage.sh Executable file
View File

@@ -0,0 +1,35 @@
#!/bin/bash
# Collects Go test coverage for every directory below the current one that
# contains Go files, merges the per-directory profiles into a single file and
# uploads that file to coveralls.io.

PROFILE_OUT=$PWD/profile.out
ACC_OUT=$PWD/acc.out

# testCover runs `go test` with coverage in one directory and appends the
# resulting profile to $ACC_OUT.
#
#   $1  directory to check; defaults to '.'
#
# Returns 0 when the directory holds no Go files or its tests did not report
# FAIL, and 1 when the test output contains FAIL.
testCover() {
    # get the directory to check from the parameter. Default to '.'
    local d=${1:-.}
    local coverageresult
    # set the return value to 0 (successful)
    local retval=0

    # skip if there are no Go files here
    ls "$d"/*.go &> /dev/null || return $retval

    # switch to the directory to check
    pushd "$d" > /dev/null

    # Remove the profile left behind by a previous directory so that it is
    # not appended a second time if this run writes no profile of its own.
    rm -f "$PROFILE_OUT"

    # create the coverage profile
    coverageresult=$(go test -v -coverprofile="$PROFILE_OUT")

    # output the result so we can check the shell output; quoted so that the
    # line structure is kept and nothing in the output is glob-expanded
    echo "${coverageresult}"

    # append the results to acc.out if coverage didn't fail, else set the
    # retval to 1 (failed). This must not run inside a ( ... ) subshell:
    # an assignment made there would be lost when the subshell exits.
    if [[ ${coverageresult} == *FAIL* ]]; then
        retval=1
    elif [ -f "$PROFILE_OUT" ]; then
        grep -v "mode: set" "$PROFILE_OUT" >> "$ACC_OUT"
    fi

    # return to our working dir
    popd > /dev/null

    # return our return value
    return $retval
}

# Init acc.out
echo "mode: set" > "$ACC_OUT"

# Run test coverage on all directories containing go files. The loop reads
# from a process substitution instead of a pipe so that it runs in this shell
# and `exit` stops the script before anything is uploaded.
while read -r d; do
    testCover "$d" || exit 1
done < <(find . -maxdepth 10 -type d)

# Upload the coverage profile to coveralls.io
[ -n "$COVERALLS_TOKEN" ] && goveralls -coverprofile="$ACC_OUT" -service=travis-ci -repotoken "$COVERALLS_TOKEN"

96
blas/README.md Normal file
View File

@@ -0,0 +1,96 @@
# Gonum BLAS [![Build Status](https://travis-ci.org/gonum/blas.svg?branch=master)](https://travis-ci.org/gonum/blas) [![Coverage Status](https://coveralls.io/repos/gonum/blas/badge.svg?branch=master&service=github)](https://coveralls.io/github/gonum/blas?branch=master) [![GoDoc](https://godoc.org/github.com/gonum/blas?status.svg)](https://godoc.org/github.com/gonum/blas)
A collection of packages that provide BLAS functionality for the [Go programming
language](http://golang.org).
## Installation
```sh
go get github.com/gonum/blas
```
### BLAS C-bindings
If you want to use OpenBLAS, install it in any directory:
```sh
git clone https://github.com/xianyi/OpenBLAS
cd OpenBLAS
make
```
The blas/cgo package provides bindings to C-backed BLAS packages. blas/cgo needs the `CGO_LDFLAGS`
environment variable to point to the blas installation. More information can be found in the
[cgo command documentation](http://golang.org/cmd/cgo/).
Then install the blas/cgo package:
```sh
CGO_LDFLAGS="-L/path/to/OpenBLAS -lopenblas" go install github.com/gonum/blas/cgo
```
For Windows you can download binary packages for OpenBLAS at
[SourceForge](http://sourceforge.net/projects/openblas/files/).
If you want to use a different BLAS package such as the Intel MKL you can
adjust the `CGO_LDFLAGS` variable:
```sh
CGO_LDFLAGS="-lmkl_rt" go install github.com/gonum/blas/cgo
```
On OS X the easiest solution is to use the libraries provided by the system:
```sh
CGO_LDFLAGS="-framework Accelerate" go install github.com/gonum/blas/cgo
```
## Packages
### blas
Defines [BLAS API](http://www.netlib.org/blas/blast-forum/cinterface.pdf) split in several
interfaces.
### blas/native
Go implementation of the BLAS API (incomplete: it implements the `float32` and `float64` APIs).
### blas/cgo
Binding to a C implementation of the cblas interface (e.g. ATLAS, OpenBLAS, Intel MKL)
The recommended (free) option for good performance on both Linux and Darwin is OpenBLAS.
### blas/blas64 and blas/blas32
Wrappers for an implementation of the double (i.e., `float64`) and single (`float32`)
precision real parts of the blas API
```Go
package main
import (
"fmt"
"github.com/gonum/blas/blas64"
)
func main() {
v := blas64.Vector{Inc: 1, Data: []float64{1, 1, 1}}
fmt.Println("v has length:", blas64.Nrm2(len(v.Data), v))
}
```
### blas/cblas128 and blas/cblas64
Wrappers for an implementation of the double (i.e., `complex128`) and single (`complex64`)
precision complex parts of the blas API
Currently blas/cblas64 and blas/cblas128 require blas/cgo.
## Issues
If you find any bugs, feel free to file an issue on the GitHub issue tracker.
Discussions on API changes, added features, code review, or similar requests
are preferred on the [gonum-dev Google Group](https://groups.google.com/forum/#!forum/gonum-dev).
## License
Please see [github.com/gonum/license](https://github.com/gonum/license) for general
license information, contributors, authors, etc on the Gonum suite of packages.

388
blas/blas.go Normal file
View File

@@ -0,0 +1,388 @@
// Copyright ©2013 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
Package blas provides interfaces for the BLAS linear algebra standard.
All methods must perform appropriate parameter checking and panic if
provided parameters that do not conform to the requirements specified
by the BLAS standard.
Quick Reference Guide to the BLAS from http://www.netlib.org/lapack/lug/node145.html
This version is modified to remove the "order" option. All matrix operations are
on row-order matrices.
Level 1 BLAS
dim scalar vector vector scalars 5-element prefixes
struct
_rotg ( a, b ) S, D
_rotmg( d1, d2, a, b ) S, D
_rot ( n, x, incX, y, incY, c, s ) S, D
_rotm ( n, x, incX, y, incY, param ) S, D
_swap ( n, x, incX, y, incY ) S, D, C, Z
_scal ( n, alpha, x, incX ) S, D, C, Z, Cs, Zd
_copy ( n, x, incX, y, incY ) S, D, C, Z
_axpy ( n, alpha, x, incX, y, incY ) S, D, C, Z
_dot ( n, x, incX, y, incY ) S, D, Ds
_dotu ( n, x, incX, y, incY ) C, Z
_dotc ( n, x, incX, y, incY ) C, Z
__dot ( n, alpha, x, incX, y, incY ) Sds
_nrm2 ( n, x, incX ) S, D, Sc, Dz
_asum ( n, x, incX ) S, D, Sc, Dz
I_amax( n, x, incX ) s, d, c, z
Level 2 BLAS
options dim b-width scalar matrix vector scalar vector prefixes
_gemv ( trans, m, n, alpha, a, lda, x, incX, beta, y, incY ) S, D, C, Z
_gbmv ( trans, m, n, kL, kU, alpha, a, lda, x, incX, beta, y, incY ) S, D, C, Z
_hemv ( uplo, n, alpha, a, lda, x, incX, beta, y, incY ) C, Z
_hbmv ( uplo, n, k, alpha, a, lda, x, incX, beta, y, incY ) C, Z
_hpmv ( uplo, n, alpha, ap, x, incX, beta, y, incY ) C, Z
_symv ( uplo, n, alpha, a, lda, x, incX, beta, y, incY ) S, D
_sbmv ( uplo, n, k, alpha, a, lda, x, incX, beta, y, incY ) S, D
_spmv ( uplo, n, alpha, ap, x, incX, beta, y, incY ) S, D
_trmv ( uplo, trans, diag, n, a, lda, x, incX ) S, D, C, Z
_tbmv ( uplo, trans, diag, n, k, a, lda, x, incX ) S, D, C, Z
_tpmv ( uplo, trans, diag, n, ap, x, incX ) S, D, C, Z
_trsv ( uplo, trans, diag, n, a, lda, x, incX ) S, D, C, Z
_tbsv ( uplo, trans, diag, n, k, a, lda, x, incX ) S, D, C, Z
_tpsv ( uplo, trans, diag, n, ap, x, incX ) S, D, C, Z
options dim scalar vector vector matrix prefixes
_ger ( m, n, alpha, x, incX, y, incY, a, lda ) S, D
_geru ( m, n, alpha, x, incX, y, incY, a, lda ) C, Z
_gerc ( m, n, alpha, x, incX, y, incY, a, lda ) C, Z
_her ( uplo, n, alpha, x, incX, a, lda ) C, Z
_hpr ( uplo, n, alpha, x, incX, ap ) C, Z
_her2 ( uplo, n, alpha, x, incX, y, incY, a, lda ) C, Z
_hpr2 ( uplo, n, alpha, x, incX, y, incY, ap ) C, Z
_syr ( uplo, n, alpha, x, incX, a, lda ) S, D
_spr ( uplo, n, alpha, x, incX, ap ) S, D
_syr2 ( uplo, n, alpha, x, incX, y, incY, a, lda ) S, D
_spr2 ( uplo, n, alpha, x, incX, y, incY, ap ) S, D
Level 3 BLAS
options dim scalar matrix matrix scalar matrix prefixes
_gemm ( transA, transB, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc ) S, D, C, Z
_symm ( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc ) S, D, C, Z
_hemm ( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc ) C, Z
_syrk ( uplo, trans, n, k, alpha, a, lda, beta, c, ldc ) S, D, C, Z
_herk ( uplo, trans, n, k, alpha, a, lda, beta, c, ldc ) C, Z
_syr2k( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc ) S, D, C, Z
_her2k( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc ) C, Z
_trmm ( side, uplo, transA, diag, m, n, alpha, a, lda, b, ldb ) S, D, C, Z
_trsm ( side, uplo, transA, diag, m, n, alpha, a, lda, b, ldb ) S, D, C, Z
Meaning of prefixes
S - float32 C - complex64
D - float64 Z - complex128
Matrix types
GE - GEneral GB - General Band
SY - SYmmetric SB - Symmetric Band SP - Symmetric Packed
HE - HErmitian HB - Hermitian Band HP - Hermitian Packed
TR - TRiangular TB - Triangular Band TP - Triangular Packed
Options
trans = NoTrans, Trans, ConjTrans
uplo = Upper, Lower
diag = Nonunit, Unit
side = Left, Right (A or op(A) on the left, or A or op(A) on the right)
For real matrices, Trans and ConjTrans have the same meaning.
For Hermitian matrices, trans = Trans is not allowed.
For complex symmetric matrices, trans = ConjTrans is not allowed.
*/
package blas
// Flag indicates the state of the H matrix of a Givens transformation.
type Flag int

// Possible states of the H matrix.
const (
	Identity    Flag = -2 // H is the identity matrix; no rotation is needed.
	Rescaling   Flag = -1 // H specifies rescaling.
	OffDiagonal Flag = 0  // Off-diagonal elements of H are units.
	Diagonal    Flag = 1  // Diagonal elements of H are units.
)
// SrotmParams contains Givens transformation parameters. It is returned
// by the Float32Level1 Srotmg method and passed to its Srotm method.
type SrotmParams struct {
// Flag indicates the state of H; see the Flag constants.
Flag
H [4]float32 // Column-major 2 by 2 matrix.
}
// DrotmParams contains Givens transformation parameters. It is returned
// by the Float64Level1 Drotmg method and passed to its Drotm method.
type DrotmParams struct {
// Flag indicates the state of H; see the Flag constants.
Flag
H [4]float64 // Column-major 2 by 2 matrix.
}
// Transpose specifies the transposition operation that a routine applies
// to a matrix argument.
type Transpose int

// Transposition options.
const (
	NoTrans   Transpose = 111
	Trans     Transpose = 112
	ConjTrans Transpose = 113
)
// Uplo specifies whether a matrix is an upper or a lower triangular
// matrix.
type Uplo int

// Triangle options.
const (
	// NOTE(review): All is presumably meant to select the whole matrix;
	// its meaning is not documented here. Confirm against its users.
	All   Uplo = 120
	Upper Uplo = 121
	Lower Uplo = 122
)
// Diag specifies whether a triangular matrix is a unit or a non-unit
// triangular matrix.
type Diag int

// Diagonal options.
const (
	NonUnit Diag = 131
	Unit    Diag = 132
)
// Side specifies the side from which a multiplication operation is
// performed.
type Side int

// Side options.
const (
	Left  Side = 141
	Right Side = 142
)
// Float32 is the complete set of single precision real BLAS routines. It
// combines the Level 1, Level 2 and Level 3 interfaces.
type Float32 interface {
Float32Level1
Float32Level2
Float32Level3
}
// Float32Level1 is the set of single precision real BLAS Level 1 routines.
//
// Vector arguments are passed as a slice together with an increment
// (x with incX, y with incY).
type Float32Level1 interface {
Sdsdot(n int, alpha float32, x []float32, incX int, y []float32, incY int) float32
// Dsdot takes float32 vectors but returns a float64 result.
Dsdot(n int, x []float32, incX int, y []float32, incY int) float64
Sdot(n int, x []float32, incX int, y []float32, incY int) float32
Snrm2(n int, x []float32, incX int) float32
Sasum(n int, x []float32, incX int) float32
Isamax(n int, x []float32, incX int) int
Sswap(n int, x []float32, incX int, y []float32, incY int)
Scopy(n int, x []float32, incX int, y []float32, incY int)
Saxpy(n int, alpha float32, x []float32, incX int, y []float32, incY int)
Srotg(a, b float32) (c, s, r, z float32)
// Srotmg returns the SrotmParams value that Srotm accepts.
Srotmg(d1, d2, b1, b2 float32) (p SrotmParams, rd1, rd2, rb1 float32)
Srot(n int, x []float32, incX int, y []float32, incY int, c, s float32)
Srotm(n int, x []float32, incX int, y []float32, incY int, p SrotmParams)
Sscal(n int, alpha float32, x []float32, incX int)
}
// Float32Level2 is the set of single precision real BLAS Level 2 routines.
//
// Matrix arguments are passed either as a slice with a leading dimension
// (a with lda) or, for the packed routines, as a slice alone (ap).
type Float32Level2 interface {
Sgemv(tA Transpose, m, n int, alpha float32, a []float32, lda int, x []float32, incX int, beta float32, y []float32, incY int)
Sgbmv(tA Transpose, m, n, kL, kU int, alpha float32, a []float32, lda int, x []float32, incX int, beta float32, y []float32, incY int)
Strmv(ul Uplo, tA Transpose, d Diag, n int, a []float32, lda int, x []float32, incX int)
Stbmv(ul Uplo, tA Transpose, d Diag, n, k int, a []float32, lda int, x []float32, incX int)
Stpmv(ul Uplo, tA Transpose, d Diag, n int, ap []float32, x []float32, incX int)
Strsv(ul Uplo, tA Transpose, d Diag, n int, a []float32, lda int, x []float32, incX int)
Stbsv(ul Uplo, tA Transpose, d Diag, n, k int, a []float32, lda int, x []float32, incX int)
Stpsv(ul Uplo, tA Transpose, d Diag, n int, ap []float32, x []float32, incX int)
Ssymv(ul Uplo, n int, alpha float32, a []float32, lda int, x []float32, incX int, beta float32, y []float32, incY int)
Ssbmv(ul Uplo, n, k int, alpha float32, a []float32, lda int, x []float32, incX int, beta float32, y []float32, incY int)
Sspmv(ul Uplo, n int, alpha float32, ap []float32, x []float32, incX int, beta float32, y []float32, incY int)
Sger(m, n int, alpha float32, x []float32, incX int, y []float32, incY int, a []float32, lda int)
Ssyr(ul Uplo, n int, alpha float32, x []float32, incX int, a []float32, lda int)
Sspr(ul Uplo, n int, alpha float32, x []float32, incX int, ap []float32)
Ssyr2(ul Uplo, n int, alpha float32, x []float32, incX int, y []float32, incY int, a []float32, lda int)
// NOTE(review): the last parameter is named a here, while Sspr and the
// _spr2 entry in the package documentation call it ap. Presumably it is
// the same packed argument; consider renaming for consistency.
Sspr2(ul Uplo, n int, alpha float32, x []float32, incX int, y []float32, incY int, a []float32)
}
// Float32Level3 is the set of single precision real BLAS Level 3 routines.
//
// Every matrix argument is passed as a slice together with its leading
// dimension (a with lda, b with ldb, c with ldc).
type Float32Level3 interface {
Sgemm(tA, tB Transpose, m, n, k int, alpha float32, a []float32, lda int, b []float32, ldb int, beta float32, c []float32, ldc int)
Ssymm(s Side, ul Uplo, m, n int, alpha float32, a []float32, lda int, b []float32, ldb int, beta float32, c []float32, ldc int)
Ssyrk(ul Uplo, t Transpose, n, k int, alpha float32, a []float32, lda int, beta float32, c []float32, ldc int)
Ssyr2k(ul Uplo, t Transpose, n, k int, alpha float32, a []float32, lda int, b []float32, ldb int, beta float32, c []float32, ldc int)
Strmm(s Side, ul Uplo, tA Transpose, d Diag, m, n int, alpha float32, a []float32, lda int, b []float32, ldb int)
Strsm(s Side, ul Uplo, tA Transpose, d Diag, m, n int, alpha float32, a []float32, lda int, b []float32, ldb int)
}
// Float64 is the complete set of double precision real BLAS routines. It
// combines the Level 1, Level 2 and Level 3 interfaces.
type Float64 interface {
Float64Level1
Float64Level2
Float64Level3
}
// Float64Level1 is the set of double precision real BLAS Level 1 routines.
//
// Vector arguments are passed as a slice together with an increment
// (x with incX, y with incY).
type Float64Level1 interface {
Ddot(n int, x []float64, incX int, y []float64, incY int) float64
Dnrm2(n int, x []float64, incX int) float64
Dasum(n int, x []float64, incX int) float64
Idamax(n int, x []float64, incX int) int
Dswap(n int, x []float64, incX int, y []float64, incY int)
Dcopy(n int, x []float64, incX int, y []float64, incY int)
Daxpy(n int, alpha float64, x []float64, incX int, y []float64, incY int)
Drotg(a, b float64) (c, s, r, z float64)
// Drotmg returns the DrotmParams value that Drotm accepts.
Drotmg(d1, d2, b1, b2 float64) (p DrotmParams, rd1, rd2, rb1 float64)
Drot(n int, x []float64, incX int, y []float64, incY int, c float64, s float64)
Drotm(n int, x []float64, incX int, y []float64, incY int, p DrotmParams)
Dscal(n int, alpha float64, x []float64, incX int)
}
// Float64Level2 is the set of double precision real BLAS Level 2 routines.
//
// Matrix arguments are passed either as a slice with a leading dimension
// (a with lda) or, for the packed routines, as a slice alone (ap).
type Float64Level2 interface {
Dgemv(tA Transpose, m, n int, alpha float64, a []float64, lda int, x []float64, incX int, beta float64, y []float64, incY int)
Dgbmv(tA Transpose, m, n, kL, kU int, alpha float64, a []float64, lda int, x []float64, incX int, beta float64, y []float64, incY int)
Dtrmv(ul Uplo, tA Transpose, d Diag, n int, a []float64, lda int, x []float64, incX int)
Dtbmv(ul Uplo, tA Transpose, d Diag, n, k int, a []float64, lda int, x []float64, incX int)
Dtpmv(ul Uplo, tA Transpose, d Diag, n int, ap []float64, x []float64, incX int)
Dtrsv(ul Uplo, tA Transpose, d Diag, n int, a []float64, lda int, x []float64, incX int)
Dtbsv(ul Uplo, tA Transpose, d Diag, n, k int, a []float64, lda int, x []float64, incX int)
Dtpsv(ul Uplo, tA Transpose, d Diag, n int, ap []float64, x []float64, incX int)
Dsymv(ul Uplo, n int, alpha float64, a []float64, lda int, x []float64, incX int, beta float64, y []float64, incY int)
Dsbmv(ul Uplo, n, k int, alpha float64, a []float64, lda int, x []float64, incX int, beta float64, y []float64, incY int)
Dspmv(ul Uplo, n int, alpha float64, ap []float64, x []float64, incX int, beta float64, y []float64, incY int)
Dger(m, n int, alpha float64, x []float64, incX int, y []float64, incY int, a []float64, lda int)
Dsyr(ul Uplo, n int, alpha float64, x []float64, incX int, a []float64, lda int)
Dspr(ul Uplo, n int, alpha float64, x []float64, incX int, ap []float64)
Dsyr2(ul Uplo, n int, alpha float64, x []float64, incX int, y []float64, incY int, a []float64, lda int)
// NOTE(review): the last parameter is named a here, while Dspr and the
// _spr2 entry in the package documentation call it ap. Presumably it is
// the same packed argument; consider renaming for consistency.
Dspr2(ul Uplo, n int, alpha float64, x []float64, incX int, y []float64, incY int, a []float64)
}
// Float64Level3 is the set of double precision real BLAS Level 3 routines.
//
// Every matrix argument is passed as a slice together with its leading
// dimension (a with lda, b with ldb, c with ldc).
type Float64Level3 interface {
Dgemm(tA, tB Transpose, m, n, k int, alpha float64, a []float64, lda int, b []float64, ldb int, beta float64, c []float64, ldc int)
Dsymm(s Side, ul Uplo, m, n int, alpha float64, a []float64, lda int, b []float64, ldb int, beta float64, c []float64, ldc int)
Dsyrk(ul Uplo, t Transpose, n, k int, alpha float64, a []float64, lda int, beta float64, c []float64, ldc int)
Dsyr2k(ul Uplo, t Transpose, n, k int, alpha float64, a []float64, lda int, b []float64, ldb int, beta float64, c []float64, ldc int)
Dtrmm(s Side, ul Uplo, tA Transpose, d Diag, m, n int, alpha float64, a []float64, lda int, b []float64, ldb int)
Dtrsm(s Side, ul Uplo, tA Transpose, d Diag, m, n int, alpha float64, a []float64, lda int, b []float64, ldb int)
}
// Complex64 is the complete set of single precision complex BLAS routines.
// It combines the Level 1, Level 2 and Level 3 interfaces.
type Complex64 interface {
Complex64Level1
Complex64Level2
Complex64Level3
}
// Complex64Level1 is the set of single precision complex BLAS Level 1
// routines.
//
// Vector arguments are passed as a slice together with an increment
// (x with incX, y with incY).
type Complex64Level1 interface {
Cdotu(n int, x []complex64, incX int, y []complex64, incY int) (dotu complex64)
Cdotc(n int, x []complex64, incX int, y []complex64, incY int) (dotc complex64)
// Scnrm2 and Scasum take a complex vector and return a float32.
Scnrm2(n int, x []complex64, incX int) float32
Scasum(n int, x []complex64, incX int) float32
Icamax(n int, x []complex64, incX int) int
Cswap(n int, x []complex64, incX int, y []complex64, incY int)
Ccopy(n int, x []complex64, incX int, y []complex64, incY int)
Caxpy(n int, alpha complex64, x []complex64, incX int, y []complex64, incY int)
Cscal(n int, alpha complex64, x []complex64, incX int)
// Csscal differs from Cscal in taking a real (float32) alpha.
Csscal(n int, alpha float32, x []complex64, incX int)
}
// Complex64Level2 is the set of single precision complex BLAS Level 2
// routines.
//
// Matrix arguments are passed either as a slice with a leading dimension
// (a with lda) or, for the packed routines, as a slice alone (ap).
type Complex64Level2 interface {
Cgemv(tA Transpose, m, n int, alpha complex64, a []complex64, lda int, x []complex64, incX int, beta complex64, y []complex64, incY int)
Cgbmv(tA Transpose, m, n, kL, kU int, alpha complex64, a []complex64, lda int, x []complex64, incX int, beta complex64, y []complex64, incY int)
Ctrmv(ul Uplo, tA Transpose, d Diag, n int, a []complex64, lda int, x []complex64, incX int)
Ctbmv(ul Uplo, tA Transpose, d Diag, n, k int, a []complex64, lda int, x []complex64, incX int)
Ctpmv(ul Uplo, tA Transpose, d Diag, n int, ap []complex64, x []complex64, incX int)
Ctrsv(ul Uplo, tA Transpose, d Diag, n int, a []complex64, lda int, x []complex64, incX int)
Ctbsv(ul Uplo, tA Transpose, d Diag, n, k int, a []complex64, lda int, x []complex64, incX int)
Ctpsv(ul Uplo, tA Transpose, d Diag, n int, ap []complex64, x []complex64, incX int)
Chemv(ul Uplo, n int, alpha complex64, a []complex64, lda int, x []complex64, incX int, beta complex64, y []complex64, incY int)
Chbmv(ul Uplo, n, k int, alpha complex64, a []complex64, lda int, x []complex64, incX int, beta complex64, y []complex64, incY int)
Chpmv(ul Uplo, n int, alpha complex64, ap []complex64, x []complex64, incX int, beta complex64, y []complex64, incY int)
Cgeru(m, n int, alpha complex64, x []complex64, incX int, y []complex64, incY int, a []complex64, lda int)
Cgerc(m, n int, alpha complex64, x []complex64, incX int, y []complex64, incY int, a []complex64, lda int)
// Cher and Chpr take a real (float32) alpha.
Cher(ul Uplo, n int, alpha float32, x []complex64, incX int, a []complex64, lda int)
// NOTE(review): the last parameter is named a here, while Chpr2 and the
// _hpr entry in the package documentation call it ap. Presumably it is
// the same packed argument; consider renaming for consistency.
Chpr(ul Uplo, n int, alpha float32, x []complex64, incX int, a []complex64)
Cher2(ul Uplo, n int, alpha complex64, x []complex64, incX int, y []complex64, incY int, a []complex64, lda int)
Chpr2(ul Uplo, n int, alpha complex64, x []complex64, incX int, y []complex64, incY int, ap []complex64)
}
// Complex64Level3 implements the single precision complex BLAS Level 3 routines.
//
// Every matrix is passed as a []complex64 slice together with its leading
// dimension (lda, ldb or ldc). Cherk takes real (float32) alpha and beta, and
// Cher2k takes a complex64 alpha with a real (float32) beta; all other
// scalars are complex64.
type Complex64Level3 interface {
Cgemm(tA, tB Transpose, m, n, k int, alpha complex64, a []complex64, lda int, b []complex64, ldb int, beta complex64, c []complex64, ldc int)
Csymm(s Side, ul Uplo, m, n int, alpha complex64, a []complex64, lda int, b []complex64, ldb int, beta complex64, c []complex64, ldc int)
Csyrk(ul Uplo, t Transpose, n, k int, alpha complex64, a []complex64, lda int, beta complex64, c []complex64, ldc int)
Csyr2k(ul Uplo, t Transpose, n, k int, alpha complex64, a []complex64, lda int, b []complex64, ldb int, beta complex64, c []complex64, ldc int)
Ctrmm(s Side, ul Uplo, tA Transpose, d Diag, m, n int, alpha complex64, a []complex64, lda int, b []complex64, ldb int)
Ctrsm(s Side, ul Uplo, tA Transpose, d Diag, m, n int, alpha complex64, a []complex64, lda int, b []complex64, ldb int)
Chemm(s Side, ul Uplo, m, n int, alpha complex64, a []complex64, lda int, b []complex64, ldb int, beta complex64, c []complex64, ldc int)
Cherk(ul Uplo, t Transpose, n, k int, alpha float32, a []complex64, lda int, beta float32, c []complex64, ldc int)
Cher2k(ul Uplo, t Transpose, n, k int, alpha complex64, a []complex64, lda int, b []complex64, ldb int, beta float32, c []complex64, ldc int)
}
// Complex128 implements the double precision complex BLAS routines.
//
// It is the union of the Complex128Level1, Complex128Level2 and
// Complex128Level3 interfaces.
type Complex128 interface {
Complex128Level1
Complex128Level2
Complex128Level3
}
// Complex128Level1 implements the double precision complex BLAS Level 1 routines.
//
// Vectors are passed as a []complex128 slice together with an increment.
// Zdotu and Zdotc return a complex128, Dznrm2 and Dzasum return a float64,
// and Izamax returns an int. Zdscal takes a real (float64) alpha.
type Complex128Level1 interface {
Zdotu(n int, x []complex128, incX int, y []complex128, incY int) (dotu complex128)
Zdotc(n int, x []complex128, incX int, y []complex128, incY int) (dotc complex128)
Dznrm2(n int, x []complex128, incX int) float64
Dzasum(n int, x []complex128, incX int) float64
Izamax(n int, x []complex128, incX int) int
Zswap(n int, x []complex128, incX int, y []complex128, incY int)
Zcopy(n int, x []complex128, incX int, y []complex128, incY int)
Zaxpy(n int, alpha complex128, x []complex128, incX int, y []complex128, incY int)
Zscal(n int, alpha complex128, x []complex128, incX int)
Zdscal(n int, alpha float64, x []complex128, incX int)
}
// Complex128Level2 implements the double precision complex BLAS Level 2 routines.
//
// Matrices in conventional or band storage are passed as a []complex128 slice
// together with a leading dimension lda; packed matrices are passed as a slice
// without a leading dimension. Vectors are passed as a slice together with an
// increment. Zher and Zhpr take a real (float64) alpha; every other scalar is
// complex128.
type Complex128Level2 interface {
Zgemv(tA Transpose, m, n int, alpha complex128, a []complex128, lda int, x []complex128, incX int, beta complex128, y []complex128, incY int)
Zgbmv(tA Transpose, m, n int, kL int, kU int, alpha complex128, a []complex128, lda int, x []complex128, incX int, beta complex128, y []complex128, incY int)
Ztrmv(ul Uplo, tA Transpose, d Diag, n int, a []complex128, lda int, x []complex128, incX int)
Ztbmv(ul Uplo, tA Transpose, d Diag, n, k int, a []complex128, lda int, x []complex128, incX int)
Ztpmv(ul Uplo, tA Transpose, d Diag, n int, ap []complex128, x []complex128, incX int)
Ztrsv(ul Uplo, tA Transpose, d Diag, n int, a []complex128, lda int, x []complex128, incX int)
Ztbsv(ul Uplo, tA Transpose, d Diag, n, k int, a []complex128, lda int, x []complex128, incX int)
Ztpsv(ul Uplo, tA Transpose, d Diag, n int, ap []complex128, x []complex128, incX int)
Zhemv(ul Uplo, n int, alpha complex128, a []complex128, lda int, x []complex128, incX int, beta complex128, y []complex128, incY int)
Zhbmv(ul Uplo, n, k int, alpha complex128, a []complex128, lda int, x []complex128, incX int, beta complex128, y []complex128, incY int)
Zhpmv(ul Uplo, n int, alpha complex128, ap []complex128, x []complex128, incX int, beta complex128, y []complex128, incY int)
Zgeru(m, n int, alpha complex128, x []complex128, incX int, y []complex128, incY int, a []complex128, lda int)
Zgerc(m, n int, alpha complex128, x []complex128, incX int, y []complex128, incY int, a []complex128, lda int)
Zher(ul Uplo, n int, alpha float64, x []complex128, incX int, a []complex128, lda int)
Zhpr(ul Uplo, n int, alpha float64, x []complex128, incX int, a []complex128)
Zher2(ul Uplo, n int, alpha complex128, x []complex128, incX int, y []complex128, incY int, a []complex128, lda int)
Zhpr2(ul Uplo, n int, alpha complex128, x []complex128, incX int, y []complex128, incY int, ap []complex128)
}
// Complex128Level3 implements the double precision complex BLAS Level 3 routines.
//
// Every matrix is passed as a []complex128 slice together with its leading
// dimension (lda, ldb or ldc). Zherk takes real (float64) alpha and beta, and
// Zher2k takes a complex128 alpha with a real (float64) beta; all other
// scalars are complex128.
type Complex128Level3 interface {
Zgemm(tA, tB Transpose, m, n, k int, alpha complex128, a []complex128, lda int, b []complex128, ldb int, beta complex128, c []complex128, ldc int)
Zsymm(s Side, ul Uplo, m, n int, alpha complex128, a []complex128, lda int, b []complex128, ldb int, beta complex128, c []complex128, ldc int)
Zsyrk(ul Uplo, t Transpose, n, k int, alpha complex128, a []complex128, lda int, beta complex128, c []complex128, ldc int)
Zsyr2k(ul Uplo, t Transpose, n, k int, alpha complex128, a []complex128, lda int, b []complex128, ldb int, beta complex128, c []complex128, ldc int)
Ztrmm(s Side, ul Uplo, tA Transpose, d Diag, m, n int, alpha complex128, a []complex128, lda int, b []complex128, ldb int)
Ztrsm(s Side, ul Uplo, tA Transpose, d Diag, m, n int, alpha complex128, a []complex128, lda int, b []complex128, ldb int)
Zhemm(s Side, ul Uplo, m, n int, alpha complex128, a []complex128, lda int, b []complex128, ldb int, beta complex128, c []complex128, ldc int)
Zherk(ul Uplo, t Transpose, n, k int, alpha float64, a []complex128, lda int, beta float64, c []complex128, ldc int)
Zher2k(ul Uplo, t Transpose, n, k int, alpha complex128, a []complex128, lda int, b []complex128, ldb int, beta float64, c []complex128, ldc int)
}

458
blas/blas32/blas32.go Normal file
View File

@@ -0,0 +1,458 @@
// Copyright ©2015 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package blas32 provides a simple interface to the float32 BLAS API.
package blas32
import (
"github.com/gonum/blas"
"github.com/gonum/blas/native"
)
// blas32 is the float32 BLAS implementation that every function in this
// package forwards to. It starts as native.Implementation{} and is replaced
// by Use.
var blas32 blas.Float32 = native.Implementation{}
// Use sets the BLAS float32 implementation to be used by subsequent BLAS calls.
// The default implementation is native.Implementation.
//
// Use is a plain assignment to a package-level variable and performs no
// synchronization.
func Use(b blas.Float32) {
blas32 = b
}
// Implementation returns the current BLAS float32 implementation.
//
// Implementation allows direct calls to the current BLAS float32
// implementation, giving finer control of parameters.
func Implementation() blas.Float32 {
return blas32
}
// Vector represents a vector with an associated element increment.
//
// The functions in this package pass Data and Inc straight through to the
// underlying implementation as a vector's data and increment arguments.
type Vector struct {
Inc int
Data []float32
}
// General represents a matrix using the conventional storage scheme.
//
// The functions in this package pass Rows, Cols, Data and Stride straight
// through to the underlying implementation without validating them.
type General struct {
Rows, Cols int
Stride int
Data []float32
}
// Band represents a band matrix using the band storage scheme.
//
// KL and KU are passed to the underlying implementation as the kL and kU
// arguments, alongside Rows, Cols, Data and Stride.
type Band struct {
Rows, Cols int
KL, KU int
Stride int
Data []float32
}
// Triangular represents a triangular matrix using the conventional storage scheme.
//
// Uplo and Diag are forwarded unchanged as the uplo and diag arguments of the
// underlying triangular routines.
type Triangular struct {
N int
Stride int
Data []float32
Uplo blas.Uplo
Diag blas.Diag
}
// TriangularBand represents a triangular matrix using the band storage scheme.
//
// N and K are forwarded as the n and k arguments of the underlying banded
// triangular routines, together with Uplo, Diag, Data and Stride.
type TriangularBand struct {
N, K int
Stride int
Data []float32
Uplo blas.Uplo
Diag blas.Diag
}
// TriangularPacked represents a triangular matrix using the packed storage scheme.
//
// Packed matrices carry no Stride; only N, Data, Uplo and Diag are forwarded
// to the underlying routines.
type TriangularPacked struct {
N int
Data []float32
Uplo blas.Uplo
Diag blas.Diag
}
// Symmetric represents a symmetric matrix using the conventional storage scheme.
//
// Uplo is forwarded unchanged as the uplo argument of the underlying
// symmetric routines.
type Symmetric struct {
N int
Stride int
Data []float32
Uplo blas.Uplo
}
// SymmetricBand represents a symmetric matrix using the band storage scheme.
//
// N and K are forwarded as the n and k arguments of the underlying banded
// symmetric routines, together with Uplo, Data and Stride.
type SymmetricBand struct {
N, K int
Stride int
Data []float32
Uplo blas.Uplo
}
// SymmetricPacked represents a symmetric matrix using the packed storage scheme.
//
// Packed matrices carry no Stride; only N, Data and Uplo are forwarded to the
// underlying routines.
type SymmetricPacked struct {
N int
Data []float32
Uplo blas.Uplo
}
// Level 1
// negInc is the panic message used by the functions in this package that
// reject a vector with a negative increment.
const negInc = "blas32: negative vector increment"
// Dot computes the dot product of the two vectors:
//  \sum_i x[i]*y[i].
//
// Dot forwards n and the data and increments of x and y to the Sdot method of
// the current implementation and returns its result.
func Dot(n int, x, y Vector) float32 {
return blas32.Sdot(n, x.Data, x.Inc, y.Data, y.Inc)
}
// DDot computes the dot product of the two vectors:
//  \sum_i x[i]*y[i].
//
// Unlike Dot, the result is returned as a float64. DDot forwards to the Dsdot
// method of the current implementation.
func DDot(n int, x, y Vector) float64 {
return blas32.Dsdot(n, x.Data, x.Inc, y.Data, y.Inc)
}
// SDDot computes the dot product of the two vectors adding a constant:
//  alpha + \sum_i x[i]*y[i].
//
// SDDot forwards to the Sdsdot method of the current implementation and
// returns its float32 result.
func SDDot(n int, alpha float32, x, y Vector) float32 {
return blas32.Sdsdot(n, alpha, x.Data, x.Inc, y.Data, y.Inc)
}
// Nrm2 returns the Euclidean norm of the vector x:
//  sqrt(\sum_i x[i]*x[i]).
//
// The increment of x is checked before the call is forwarded to the Snrm2
// method of the current implementation; Nrm2 panics if it is negative.
func Nrm2(n int, x Vector) float32 {
	data, inc := x.Data, x.Inc
	if inc < 0 {
		panic(negInc)
	}
	return blas32.Snrm2(n, data, inc)
}
// Asum returns the sum of the absolute values of the elements of x:
//  \sum_i |x[i]|.
//
// The increment of x is checked before the call is forwarded to the Sasum
// method of the current implementation; Asum panics if it is negative.
func Asum(n int, x Vector) float32 {
	data, inc := x.Data, x.Inc
	if inc < 0 {
		panic(negInc)
	}
	return blas32.Sasum(n, data, inc)
}
// Iamax returns the index of an element of x with the largest absolute value.
// When several elements share that value the earliest index is returned.
// Iamax returns -1 if n == 0.
//
// The increment of x is checked before the call is forwarded to the Isamax
// method of the current implementation; Iamax panics if it is negative.
func Iamax(n int, x Vector) int {
	data, inc := x.Data, x.Inc
	if inc < 0 {
		panic(negInc)
	}
	return blas32.Isamax(n, data, inc)
}
// Swap exchanges the elements of the two vectors:
//  x[i], y[i] = y[i], x[i] for all i.
//
// Swap forwards to the Sswap method of the current implementation.
func Swap(n int, x, y Vector) {
blas32.Sswap(n, x.Data, x.Inc, y.Data, y.Inc)
}
// Copy copies the elements of x into the elements of y:
//  y[i] = x[i] for all i.
//
// Copy forwards to the Scopy method of the current implementation.
func Copy(n int, x, y Vector) {
blas32.Scopy(n, x.Data, x.Inc, y.Data, y.Inc)
}
// Axpy adds x scaled by alpha to y:
//  y[i] += alpha*x[i] for all i.
//
// Axpy forwards to the Saxpy method of the current implementation.
func Axpy(n int, alpha float32, x, y Vector) {
blas32.Saxpy(n, alpha, x.Data, x.Inc, y.Data, y.Inc)
}
// Rotg computes the parameters of a Givens plane rotation so that
//  ⎡ c s⎤   ⎡a⎤   ⎡r⎤
//  ⎣-s c⎦ * ⎣b⎦ = ⎣0⎦
// where a and b are the Cartesian coordinates of a given point.
// c, s, and r are defined as
//  r = ±Sqrt(a^2 + b^2),
//  c = a/r, the cosine of the rotation angle,
//  s = b/r, the sine of the rotation angle,
// and z is defined such that
//  if |a| > |b|,        z = s,
//  otherwise if c != 0, z = 1/c,
//  otherwise            z = 1.
//
// Rotg forwards to the Srotg method of the current implementation.
func Rotg(a, b float32) (c, s, r, z float32) {
return blas32.Srotg(a, b)
}
// Rotmg computes the modified Givens rotation. See
// http://www.netlib.org/lapack/explore-html/df/deb/drotmg_8f.html
// for more details.
//
// Rotmg returns the four results of the current implementation's Srotmg
// method unchanged.
func Rotmg(d1, d2, b1, b2 float32) (p blas.SrotmParams, rd1, rd2, rb1 float32) {
return blas32.Srotmg(d1, d2, b1, b2)
}
// Rot applies a plane transformation to n points represented by the vectors x
// and y:
//  x[i] =  c*x[i] + s*y[i],
//  y[i] = -s*x[i] + c*y[i], for all i.
//
// Rot forwards to the Srot method of the current implementation.
func Rot(n int, x, y Vector, c, s float32) {
blas32.Srot(n, x.Data, x.Inc, y.Data, y.Inc, c, s)
}
// Rotm applies the modified Givens rotation to n points represented by the
// vectors x and y.
//
// The rotation is described by p, which is passed unchanged to the Srotm
// method of the current implementation.
func Rotm(n int, x, y Vector, p blas.SrotmParams) {
blas32.Srotm(n, x.Data, x.Inc, y.Data, y.Inc, p)
}
// Scal multiplies every element of the vector x by alpha:
//  x[i] *= alpha for all i.
//
// The increment of x is checked before the call is forwarded to the Sscal
// method of the current implementation; Scal panics if it is negative.
func Scal(n int, alpha float32, x Vector) {
	data, inc := x.Data, x.Inc
	if inc < 0 {
		panic(negInc)
	}
	blas32.Sscal(n, alpha, data, inc)
}
// Level 2
// Gemv computes
//  y = alpha * A * x + beta * y,   if t == blas.NoTrans,
//  y = alpha * A^T * x + beta * y, if t == blas.Trans or blas.ConjTrans,
// where A is an m×n dense matrix, x and y are vectors, and alpha and beta are scalars.
//
// m and n are taken from a.Rows and a.Cols. Gemv forwards to the Sgemv method
// of the current implementation without checking the arguments itself.
func Gemv(t blas.Transpose, alpha float32, a General, x Vector, beta float32, y Vector) {
blas32.Sgemv(t, a.Rows, a.Cols, alpha, a.Data, a.Stride, x.Data, x.Inc, beta, y.Data, y.Inc)
}
// Gbmv computes
//  y = alpha * A * x + beta * y,   if t == blas.NoTrans,
//  y = alpha * A^T * x + beta * y, if t == blas.Trans or blas.ConjTrans,
// where A is an m×n band matrix, x and y are vectors, and alpha and beta are scalars.
//
// The dimensions and band widths are taken from a.Rows, a.Cols, a.KL and
// a.KU. Gbmv forwards to the Sgbmv method of the current implementation.
func Gbmv(t blas.Transpose, alpha float32, a Band, x Vector, beta float32, y Vector) {
blas32.Sgbmv(t, a.Rows, a.Cols, a.KL, a.KU, alpha, a.Data, a.Stride, x.Data, x.Inc, beta, y.Data, y.Inc)
}
// Trmv computes
//  x = A * x,   if t == blas.NoTrans,
//  x = A^T * x, if t == blas.Trans or blas.ConjTrans,
// where A is an n×n triangular matrix, and x is a vector.
//
// The order n, the triangle and the diagonal kind are taken from a.N, a.Uplo
// and a.Diag. Trmv forwards to the Strmv method of the current implementation.
func Trmv(t blas.Transpose, a Triangular, x Vector) {
blas32.Strmv(a.Uplo, t, a.Diag, a.N, a.Data, a.Stride, x.Data, x.Inc)
}
// Tbmv computes
//  x = A * x,   if t == blas.NoTrans,
//  x = A^T * x, if t == blas.Trans or blas.ConjTrans,
// where A is an n×n triangular band matrix, and x is a vector.
//
// n and k are taken from a.N and a.K. Tbmv forwards to the Stbmv method of
// the current implementation.
func Tbmv(t blas.Transpose, a TriangularBand, x Vector) {
blas32.Stbmv(a.Uplo, t, a.Diag, a.N, a.K, a.Data, a.Stride, x.Data, x.Inc)
}
// Tpmv computes
//  x = A * x,   if t == blas.NoTrans,
//  x = A^T * x, if t == blas.Trans or blas.ConjTrans,
// where A is an n×n triangular matrix in packed format, and x is a vector.
//
// n is taken from a.N. Tpmv forwards to the Stpmv method of the current
// implementation.
func Tpmv(t blas.Transpose, a TriangularPacked, x Vector) {
blas32.Stpmv(a.Uplo, t, a.Diag, a.N, a.Data, x.Data, x.Inc)
}
// Trsv solves
//  A * x = b,   if t == blas.NoTrans,
//  A^T * x = b, if t == blas.Trans or blas.ConjTrans,
// where A is an n×n triangular matrix, and x and b are vectors.
//
// At entry to the function, x contains the values of b, and the result is
// stored in-place into x.
//
// No test for singularity or near-singularity is included in this
// routine. Such tests must be performed before calling this routine.
//
// Trsv forwards to the Strsv method of the current implementation.
func Trsv(t blas.Transpose, a Triangular, x Vector) {
blas32.Strsv(a.Uplo, t, a.Diag, a.N, a.Data, a.Stride, x.Data, x.Inc)
}
// Tbsv solves
//  A * x = b,   if t == blas.NoTrans,
//  A^T * x = b, if t == blas.Trans or blas.ConjTrans,
// where A is an n×n triangular band matrix, and x and b are vectors.
//
// At entry to the function, x contains the values of b, and the result is
// stored in-place into x.
//
// No test for singularity or near-singularity is included in this
// routine. Such tests must be performed before calling this routine.
//
// Tbsv forwards to the Stbsv method of the current implementation.
func Tbsv(t blas.Transpose, a TriangularBand, x Vector) {
blas32.Stbsv(a.Uplo, t, a.Diag, a.N, a.K, a.Data, a.Stride, x.Data, x.Inc)
}
// Tpsv solves
//  A * x = b,   if t == blas.NoTrans,
//  A^T * x = b, if t == blas.Trans or blas.ConjTrans,
// where A is an n×n triangular matrix in packed format, and x and b are
// vectors.
//
// At entry to the function, x contains the values of b, and the result is
// stored in-place into x.
//
// No test for singularity or near-singularity is included in this
// routine. Such tests must be performed before calling this routine.
//
// Tpsv forwards to the Stpsv method of the current implementation.
func Tpsv(t blas.Transpose, a TriangularPacked, x Vector) {
blas32.Stpsv(a.Uplo, t, a.Diag, a.N, a.Data, x.Data, x.Inc)
}
// Symv computes
//  y = alpha * A * x + beta * y,
// where A is an n×n symmetric matrix, x and y are vectors, and alpha and
// beta are scalars.
//
// n and the stored triangle are taken from a.N and a.Uplo. Symv forwards to
// the Ssymv method of the current implementation.
func Symv(alpha float32, a Symmetric, x Vector, beta float32, y Vector) {
blas32.Ssymv(a.Uplo, a.N, alpha, a.Data, a.Stride, x.Data, x.Inc, beta, y.Data, y.Inc)
}
// Sbmv performs
//  y = alpha * A * x + beta * y,
// where A is an n×n symmetric band matrix, x and y are vectors, and alpha
// and beta are scalars.
//
// n and k are taken from a.N and a.K. Sbmv forwards to the Ssbmv method of
// the current implementation.
func Sbmv(alpha float32, a SymmetricBand, x Vector, beta float32, y Vector) {
blas32.Ssbmv(a.Uplo, a.N, a.K, alpha, a.Data, a.Stride, x.Data, x.Inc, beta, y.Data, y.Inc)
}
// Spmv performs
//  y = alpha * A * x + beta * y,
// where A is an n×n symmetric matrix in packed format, x and y are vectors,
// and alpha and beta are scalars.
//
// n is taken from a.N. Spmv forwards to the Sspmv method of the current
// implementation.
func Spmv(alpha float32, a SymmetricPacked, x Vector, beta float32, y Vector) {
blas32.Sspmv(a.Uplo, a.N, alpha, a.Data, x.Data, x.Inc, beta, y.Data, y.Inc)
}
// Ger performs a rank-1 update
//  A += alpha * x * y^T,
// where A is an m×n dense matrix, x and y are vectors, and alpha is a scalar.
//
// m and n are taken from a.Rows and a.Cols. Ger forwards to the Sger method
// of the current implementation.
func Ger(alpha float32, x, y Vector, a General) {
blas32.Sger(a.Rows, a.Cols, alpha, x.Data, x.Inc, y.Data, y.Inc, a.Data, a.Stride)
}
// Syr performs a rank-1 update
//  A += alpha * x * x^T,
// where A is an n×n symmetric matrix, x is a vector, and alpha is a scalar.
//
// n and the stored triangle are taken from a.N and a.Uplo. Syr forwards to
// the Ssyr method of the current implementation.
func Syr(alpha float32, x Vector, a Symmetric) {
blas32.Ssyr(a.Uplo, a.N, alpha, x.Data, x.Inc, a.Data, a.Stride)
}
// Spr performs the rank-1 update
//  A += alpha * x * x^T,
// where A is an n×n symmetric matrix in packed format, x is a vector, and
// alpha is a scalar.
//
// n is taken from a.N. Spr forwards to the Sspr method of the current
// implementation.
func Spr(alpha float32, x Vector, a SymmetricPacked) {
blas32.Sspr(a.Uplo, a.N, alpha, x.Data, x.Inc, a.Data)
}
// Syr2 performs a rank-2 update
//  A += alpha * x * y^T + alpha * y * x^T,
// where A is a symmetric n×n matrix, x and y are vectors, and alpha is a scalar.
//
// n and the stored triangle are taken from a.N and a.Uplo. Syr2 forwards to
// the Ssyr2 method of the current implementation.
func Syr2(alpha float32, x, y Vector, a Symmetric) {
blas32.Ssyr2(a.Uplo, a.N, alpha, x.Data, x.Inc, y.Data, y.Inc, a.Data, a.Stride)
}
// Spr2 performs a rank-2 update
//  A += alpha * x * y^T + alpha * y * x^T,
// where A is an n×n symmetric matrix in packed format, x and y are vectors,
// and alpha is a scalar.
//
// n is taken from a.N. Spr2 forwards to the Sspr2 method of the current
// implementation.
func Spr2(alpha float32, x, y Vector, a SymmetricPacked) {
blas32.Sspr2(a.Uplo, a.N, alpha, x.Data, x.Inc, y.Data, y.Inc, a.Data)
}
// Level 3
// Gemm computes
//  C = alpha * A * B + beta * C,
// where A, B, and C are dense matrices, and alpha and beta are scalars.
// tA and tB specify whether A or B are transposed.
//
// m and k are read from a and n is read from b, each after accounting for
// the requested transposition. The dimensions of c are not consulted by this
// wrapper; only c.Data and c.Stride are forwarded to Sgemm.
func Gemm(tA, tB blas.Transpose, alpha float32, a, b General, beta float32, c General) {
	// Start from the untransposed shapes and swap when a transpose is requested.
	m, k := a.Rows, a.Cols
	if tA != blas.NoTrans {
		m, k = k, m
	}
	n := b.Cols
	if tB != blas.NoTrans {
		n = b.Rows
	}
	blas32.Sgemm(tA, tB, m, n, k, alpha, a.Data, a.Stride, b.Data, b.Stride, beta, c.Data, c.Stride)
}
// Symm performs
//  C = alpha * A * B + beta * C, if s == blas.Left,
//  C = alpha * B * A + beta * C, if s == blas.Right,
// where A is an n×n or m×m symmetric matrix, B and C are m×n matrices, and
// alpha is a scalar.
//
// With s == blas.Left, m is a.N and n is b.Cols; otherwise m is b.Rows and n
// is a.N. The dimensions of c are not consulted by this wrapper.
func Symm(s blas.Side, alpha float32, a Symmetric, b General, beta float32, c General) {
	// Default to the right-side shapes and override for a left-side multiply.
	m, n := b.Rows, a.N
	if s == blas.Left {
		m, n = a.N, b.Cols
	}
	blas32.Ssymm(s, a.Uplo, m, n, alpha, a.Data, a.Stride, b.Data, b.Stride, beta, c.Data, c.Stride)
}
// Syrk performs a symmetric rank-k update
//  C = alpha * A * A^T + beta * C, if t == blas.NoTrans,
//  C = alpha * A^T * A + beta * C, if t == blas.Trans or blas.ConjTrans,
// where C is an n×n symmetric matrix, A is an n×k matrix if t == blas.NoTrans and
// a k×n matrix otherwise, and alpha and beta are scalars.
//
// n and k are read from a; c.N is not consulted by this wrapper.
func Syrk(t blas.Transpose, alpha float32, a General, beta float32, c Symmetric) {
	// Start from the untransposed shape and swap when a transpose is requested.
	n, k := a.Rows, a.Cols
	if t != blas.NoTrans {
		n, k = k, n
	}
	blas32.Ssyrk(c.Uplo, t, n, k, alpha, a.Data, a.Stride, beta, c.Data, c.Stride)
}
// Syr2k performs a symmetric rank-2k update
//  C = alpha * A * B^T + alpha * B * A^T + beta * C, if t == blas.NoTrans,
//  C = alpha * A^T * B + alpha * B^T * A + beta * C, if t == blas.Trans or blas.ConjTrans,
// where C is an n×n symmetric matrix, A and B are n×k matrices if t == NoTrans
// and k×n matrices otherwise, and alpha and beta are scalars.
//
// n and k are read from a only; the dimensions of b and c.N are not
// consulted by this wrapper.
func Syr2k(t blas.Transpose, alpha float32, a, b General, beta float32, c Symmetric) {
	// Start from the untransposed shape and swap when a transpose is requested.
	n, k := a.Rows, a.Cols
	if t != blas.NoTrans {
		n, k = k, n
	}
	blas32.Ssyr2k(c.Uplo, t, n, k, alpha, a.Data, a.Stride, b.Data, b.Stride, beta, c.Data, c.Stride)
}
// Trmm performs
//  B = alpha * A * B,   if tA == blas.NoTrans and s == blas.Left,
//  B = alpha * A^T * B, if tA == blas.Trans or blas.ConjTrans, and s == blas.Left,
//  B = alpha * B * A,   if tA == blas.NoTrans and s == blas.Right,
//  B = alpha * B * A^T, if tA == blas.Trans or blas.ConjTrans, and s == blas.Right,
// where A is an n×n or m×m triangular matrix, B is an m×n matrix, and alpha is
// a scalar.
//
// m and n are taken from b.Rows and b.Cols; a.N is not passed to the
// underlying Strmm call.
func Trmm(s blas.Side, tA blas.Transpose, alpha float32, a Triangular, b General) {
blas32.Strmm(s, a.Uplo, tA, a.Diag, b.Rows, b.Cols, alpha, a.Data, a.Stride, b.Data, b.Stride)
}
// Trsm solves
//  A * X = alpha * B,   if tA == blas.NoTrans and s == blas.Left,
//  A^T * X = alpha * B, if tA == blas.Trans or blas.ConjTrans, and s == blas.Left,
//  X * A = alpha * B,   if tA == blas.NoTrans and s == blas.Right,
//  X * A^T = alpha * B, if tA == blas.Trans or blas.ConjTrans, and s == blas.Right,
// where A is an n×n or m×m triangular matrix, X and B are m×n matrices, and
// alpha is a scalar.
//
// At entry to the function, b contains the values of B, and the result X is
// stored in-place into b.
//
// No check is made that A is invertible.
//
// m and n are taken from b.Rows and b.Cols; a.N is not passed to the
// underlying Strsm call.
func Trsm(s blas.Side, tA blas.Transpose, alpha float32, a Triangular, b General) {
blas32.Strsm(s, a.Uplo, tA, a.Diag, b.Rows, b.Cols, alpha, a.Data, a.Stride, b.Data, b.Stride)
}

446
blas/blas64/blas64.go Normal file
View File

@@ -0,0 +1,446 @@
// Copyright ©2015 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package blas64 provides a simple interface to the float64 BLAS API.
package blas64
import (
"github.com/gonum/blas"
"github.com/gonum/blas/native"
)
// blas64 is the float64 BLAS implementation that every function in this
// package forwards to. It starts as native.Implementation{} and is replaced
// by Use.
var blas64 blas.Float64 = native.Implementation{}
// Use sets the BLAS float64 implementation to be used by subsequent BLAS calls.
// The default implementation is native.Implementation.
//
// Use is a plain assignment to a package-level variable and performs no
// synchronization.
func Use(b blas.Float64) {
blas64 = b
}
// Implementation returns the current BLAS float64 implementation.
//
// Implementation allows direct calls to the current BLAS float64
// implementation, giving finer control of parameters.
func Implementation() blas.Float64 {
return blas64
}
// Vector represents a vector with an associated element increment.
//
// The functions in this package pass Data and Inc straight through to the
// underlying implementation as a vector's data and increment arguments.
type Vector struct {
Inc int
Data []float64
}
// General represents a matrix using the conventional storage scheme.
//
// The functions in this package pass Rows, Cols, Data and Stride straight
// through to the underlying implementation without validating them.
type General struct {
Rows, Cols int
Stride int
Data []float64
}
// Band represents a band matrix using the band storage scheme.
//
// KL and KU are passed to the underlying implementation as the kL and kU
// arguments, alongside Rows, Cols, Data and Stride.
type Band struct {
Rows, Cols int
KL, KU int
Stride int
Data []float64
}
// Triangular represents a triangular matrix using the conventional storage scheme.
//
// Uplo and Diag are forwarded unchanged as the uplo and diag arguments of the
// underlying triangular routines.
type Triangular struct {
N int
Stride int
Data []float64
Uplo blas.Uplo
Diag blas.Diag
}
// TriangularBand represents a triangular matrix using the band storage scheme.
//
// N and K are forwarded as the n and k arguments of the underlying banded
// triangular routines, together with Uplo, Diag, Data and Stride.
type TriangularBand struct {
N, K int
Stride int
Data []float64
Uplo blas.Uplo
Diag blas.Diag
}
// TriangularPacked represents a triangular matrix using the packed storage scheme.
//
// Packed matrices carry no Stride; only N, Data, Uplo and Diag are forwarded
// to the underlying routines.
type TriangularPacked struct {
N int
Data []float64
Uplo blas.Uplo
Diag blas.Diag
}
// Symmetric represents a symmetric matrix using the conventional storage scheme.
//
// Uplo is forwarded unchanged as the uplo argument of the underlying
// symmetric routines.
type Symmetric struct {
N int
Stride int
Data []float64
Uplo blas.Uplo
}
// SymmetricBand represents a symmetric matrix using the band storage scheme.
//
// N and K are forwarded as the n and k arguments of the underlying banded
// symmetric routines, together with Uplo, Data and Stride.
type SymmetricBand struct {
N, K int
Stride int
Data []float64
Uplo blas.Uplo
}
// SymmetricPacked represents a symmetric matrix using the packed storage scheme.
//
// Packed matrices carry no Stride; only N, Data and Uplo are forwarded to the
// underlying routines.
type SymmetricPacked struct {
N int
Data []float64
Uplo blas.Uplo
}
// Level 1
// negInc is the panic message used by the functions in this package that
// reject a vector with a negative increment.
const negInc = "blas64: negative vector increment"
// Dot computes the dot product of the two vectors:
//  \sum_i x[i]*y[i].
//
// Dot forwards n and the data and increments of x and y to the Ddot method of
// the current implementation and returns its result.
func Dot(n int, x, y Vector) float64 {
return blas64.Ddot(n, x.Data, x.Inc, y.Data, y.Inc)
}
// Nrm2 returns the Euclidean norm of the vector x:
//  sqrt(\sum_i x[i]*x[i]).
//
// The increment of x is checked before the call is forwarded to the Dnrm2
// method of the current implementation; Nrm2 panics if it is negative.
func Nrm2(n int, x Vector) float64 {
	data, inc := x.Data, x.Inc
	if inc < 0 {
		panic(negInc)
	}
	return blas64.Dnrm2(n, data, inc)
}
// Asum returns the sum of the absolute values of the elements of x:
//  \sum_i |x[i]|.
//
// The increment of x is checked before the call is forwarded to the Dasum
// method of the current implementation; Asum panics if it is negative.
func Asum(n int, x Vector) float64 {
	data, inc := x.Data, x.Inc
	if inc < 0 {
		panic(negInc)
	}
	return blas64.Dasum(n, data, inc)
}
// Iamax returns the index of an element of x with the largest absolute value.
// When several elements share that value the earliest index is returned.
// Iamax returns -1 if n == 0.
//
// The increment of x is checked before the call is forwarded to the Idamax
// method of the current implementation; Iamax panics if it is negative.
func Iamax(n int, x Vector) int {
	data, inc := x.Data, x.Inc
	if inc < 0 {
		panic(negInc)
	}
	return blas64.Idamax(n, data, inc)
}
// Swap exchanges the elements of the two vectors:
//  x[i], y[i] = y[i], x[i] for all i.
//
// Swap forwards to the Dswap method of the current implementation.
func Swap(n int, x, y Vector) {
blas64.Dswap(n, x.Data, x.Inc, y.Data, y.Inc)
}
// Copy copies the elements of x into the elements of y:
//  y[i] = x[i] for all i.
//
// Copy forwards to the Dcopy method of the current implementation.
func Copy(n int, x, y Vector) {
blas64.Dcopy(n, x.Data, x.Inc, y.Data, y.Inc)
}
// Axpy adds x scaled by alpha to y:
//  y[i] += alpha*x[i] for all i.
//
// Axpy forwards to the Daxpy method of the current implementation.
func Axpy(n int, alpha float64, x, y Vector) {
blas64.Daxpy(n, alpha, x.Data, x.Inc, y.Data, y.Inc)
}
// Rotg computes the parameters of a Givens plane rotation so that
//  ⎡ c s⎤   ⎡a⎤   ⎡r⎤
//  ⎣-s c⎦ * ⎣b⎦ = ⎣0⎦
// where a and b are the Cartesian coordinates of a given point.
// c, s, and r are defined as
//  r = ±Sqrt(a^2 + b^2),
//  c = a/r, the cosine of the rotation angle,
//  s = b/r, the sine of the rotation angle,
// and z is defined such that
//  if |a| > |b|,        z = s,
//  otherwise if c != 0, z = 1/c,
//  otherwise            z = 1.
//
// Rotg forwards to the Drotg method of the current implementation.
func Rotg(a, b float64) (c, s, r, z float64) {
return blas64.Drotg(a, b)
}
// Rotmg computes the modified Givens rotation. See
// http://www.netlib.org/lapack/explore-html/df/deb/drotmg_8f.html
// for more details.
//
// Rotmg returns the four results of the current implementation's Drotmg
// method unchanged.
func Rotmg(d1, d2, b1, b2 float64) (p blas.DrotmParams, rd1, rd2, rb1 float64) {
return blas64.Drotmg(d1, d2, b1, b2)
}
// Rot applies a plane transformation to n points represented by the vectors x
// and y:
//  x[i] =  c*x[i] + s*y[i],
//  y[i] = -s*x[i] + c*y[i], for all i.
//
// Rot forwards to the Drot method of the current implementation.
func Rot(n int, x, y Vector, c, s float64) {
blas64.Drot(n, x.Data, x.Inc, y.Data, y.Inc, c, s)
}
// Rotm applies the modified Givens rotation to n points represented by the
// vectors x and y.
//
// The rotation is described by p, which is passed unchanged to the Drotm
// method of the current implementation.
func Rotm(n int, x, y Vector, p blas.DrotmParams) {
blas64.Drotm(n, x.Data, x.Inc, y.Data, y.Inc, p)
}
// Scal multiplies every element of the vector x by alpha:
//  x[i] *= alpha for all i.
//
// The increment of x is checked before the call is forwarded to the Dscal
// method of the current implementation; Scal panics if it is negative.
func Scal(n int, alpha float64, x Vector) {
	data, inc := x.Data, x.Inc
	if inc < 0 {
		panic(negInc)
	}
	blas64.Dscal(n, alpha, data, inc)
}
// Level 2
// Gemv computes
//  y = alpha * A * x + beta * y,   if t == blas.NoTrans,
//  y = alpha * A^T * x + beta * y, if t == blas.Trans or blas.ConjTrans,
// where A is an m×n dense matrix, x and y are vectors, and alpha and beta are scalars.
//
// m and n are taken from a.Rows and a.Cols. Gemv forwards to the Dgemv method
// of the current implementation without checking the arguments itself.
func Gemv(t blas.Transpose, alpha float64, a General, x Vector, beta float64, y Vector) {
blas64.Dgemv(t, a.Rows, a.Cols, alpha, a.Data, a.Stride, x.Data, x.Inc, beta, y.Data, y.Inc)
}
// Gbmv computes
//  y = alpha * A * x + beta * y,   if t == blas.NoTrans,
//  y = alpha * A^T * x + beta * y, if t == blas.Trans or blas.ConjTrans,
// where A is an m×n band matrix, x and y are vectors, and alpha and beta are scalars.
//
// The dimensions and band widths are taken from a.Rows, a.Cols, a.KL and
// a.KU. Gbmv forwards to the Dgbmv method of the current implementation.
func Gbmv(t blas.Transpose, alpha float64, a Band, x Vector, beta float64, y Vector) {
blas64.Dgbmv(t, a.Rows, a.Cols, a.KL, a.KU, alpha, a.Data, a.Stride, x.Data, x.Inc, beta, y.Data, y.Inc)
}
// Trmv computes
//  x = A * x,   if t == blas.NoTrans,
//  x = A^T * x, if t == blas.Trans or blas.ConjTrans,
// where A is an n×n triangular matrix, and x is a vector.
//
// The order n, the triangle and the diagonal kind are taken from a.N, a.Uplo
// and a.Diag. Trmv forwards to the Dtrmv method of the current implementation.
func Trmv(t blas.Transpose, a Triangular, x Vector) {
blas64.Dtrmv(a.Uplo, t, a.Diag, a.N, a.Data, a.Stride, x.Data, x.Inc)
}
// Tbmv computes
//  x = A * x,   if t == blas.NoTrans,
//  x = A^T * x, if t == blas.Trans or blas.ConjTrans,
// where A is an n×n triangular band matrix, and x is a vector.
//
// n and k are taken from a.N and a.K. Tbmv forwards to the Dtbmv method of
// the current implementation.
func Tbmv(t blas.Transpose, a TriangularBand, x Vector) {
blas64.Dtbmv(a.Uplo, t, a.Diag, a.N, a.K, a.Data, a.Stride, x.Data, x.Inc)
}
// Tpmv computes
//  x = A * x,   if t == blas.NoTrans,
//  x = A^T * x, if t == blas.Trans or blas.ConjTrans,
// where A is an n×n triangular matrix in packed format, and x is a vector.
//
// n is taken from a.N. Tpmv forwards to the Dtpmv method of the current
// implementation.
func Tpmv(t blas.Transpose, a TriangularPacked, x Vector) {
blas64.Dtpmv(a.Uplo, t, a.Diag, a.N, a.Data, x.Data, x.Inc)
}
// Trsv solves
//  A * x = b,   if t == blas.NoTrans,
//  A^T * x = b, if t == blas.Trans or blas.ConjTrans,
// where A is an n×n triangular matrix, and x and b are vectors.
//
// At entry to the function, x contains the values of b, and the result is
// stored in-place into x.
//
// No test for singularity or near-singularity is included in this
// routine. Such tests must be performed before calling this routine.
//
// Trsv forwards to the Dtrsv method of the current implementation.
func Trsv(t blas.Transpose, a Triangular, x Vector) {
blas64.Dtrsv(a.Uplo, t, a.Diag, a.N, a.Data, a.Stride, x.Data, x.Inc)
}
// Tbsv solves
//  A * x = b,   if t == blas.NoTrans,
//  A^T * x = b, if t == blas.Trans or blas.ConjTrans,
// where A is an n×n triangular band matrix, and x and b are vectors.
//
// At entry to the function, x contains the values of b, and the result is
// stored in-place into x.
//
// No test for singularity or near-singularity is included in this
// routine. Such tests must be performed before calling this routine.
//
// Tbsv forwards to the Dtbsv method of the current implementation.
func Tbsv(t blas.Transpose, a TriangularBand, x Vector) {
blas64.Dtbsv(a.Uplo, t, a.Diag, a.N, a.K, a.Data, a.Stride, x.Data, x.Inc)
}
// Tpsv solves
//  A * x = b,   if t == blas.NoTrans,
//  A^T * x = b, if t == blas.Trans or blas.ConjTrans,
// where A is an n×n triangular matrix in packed format, and x and b are
// vectors.
//
// At entry to the function, x contains the values of b, and the result is
// stored in-place into x.
//
// No test for singularity or near-singularity is included in this
// routine. Such tests must be performed before calling this routine.
//
// Tpsv forwards to the Dtpsv method of the current implementation.
func Tpsv(t blas.Transpose, a TriangularPacked, x Vector) {
blas64.Dtpsv(a.Uplo, t, a.Diag, a.N, a.Data, x.Data, x.Inc)
}
// Symv computes
//  y = alpha * A * x + beta * y,
// where A is an n×n symmetric matrix, x and y are vectors, and alpha and
// beta are scalars.
//
// n and the stored triangle are taken from a.N and a.Uplo. Symv forwards to
// the Dsymv method of the current implementation.
func Symv(alpha float64, a Symmetric, x Vector, beta float64, y Vector) {
blas64.Dsymv(a.Uplo, a.N, alpha, a.Data, a.Stride, x.Data, x.Inc, beta, y.Data, y.Inc)
}
// Sbmv performs
//  y = alpha * A * x + beta * y,
// where A is an n×n symmetric band matrix, x and y are vectors, and alpha
// and beta are scalars.
//
// n and k are taken from a.N and a.K. Sbmv forwards to the Dsbmv method of
// the current implementation.
func Sbmv(alpha float64, a SymmetricBand, x Vector, beta float64, y Vector) {
blas64.Dsbmv(a.Uplo, a.N, a.K, alpha, a.Data, a.Stride, x.Data, x.Inc, beta, y.Data, y.Inc)
}
// Spmv performs
//  y = alpha * A * x + beta * y,
// where A is an n×n symmetric matrix in packed format, x and y are vectors,
// and alpha and beta are scalars.
//
// n is taken from a.N. Spmv forwards to the Dspmv method of the current
// implementation.
func Spmv(alpha float64, a SymmetricPacked, x Vector, beta float64, y Vector) {
blas64.Dspmv(a.Uplo, a.N, alpha, a.Data, x.Data, x.Inc, beta, y.Data, y.Inc)
}
// Ger performs a rank-1 update
//  A += alpha * x * y^T,
// where A is an m×n dense matrix, x and y are vectors, and alpha is a scalar.
//
// m and n are taken from a.Rows and a.Cols. Ger forwards to the Dger method
// of the current implementation.
func Ger(alpha float64, x, y Vector, a General) {
blas64.Dger(a.Rows, a.Cols, alpha, x.Data, x.Inc, y.Data, y.Inc, a.Data, a.Stride)
}
// Syr performs a rank-1 update
//  A += alpha * x * x^T,
// where A is an n×n symmetric matrix, x is a vector, and alpha is a scalar.
//
// n and the stored triangle are taken from a.N and a.Uplo. Syr forwards to
// the Dsyr method of the current implementation.
func Syr(alpha float64, x Vector, a Symmetric) {
blas64.Dsyr(a.Uplo, a.N, alpha, x.Data, x.Inc, a.Data, a.Stride)
}
// Spr performs the rank-1 update
//  A += alpha * x * x^T,
// where A is an n×n symmetric matrix in packed format, x is a vector, and
// alpha is a scalar.
//
// n is taken from a.N. Spr forwards to the Dspr method of the current
// implementation.
func Spr(alpha float64, x Vector, a SymmetricPacked) {
blas64.Dspr(a.Uplo, a.N, alpha, x.Data, x.Inc, a.Data)
}
// Syr2 performs a rank-2 update
//  A += alpha * x * y^T + alpha * y * x^T,
// where A is a symmetric n×n matrix, x and y are vectors, and alpha is a scalar.
//
// n and the stored triangle are taken from a.N and a.Uplo. Syr2 forwards to
// the Dsyr2 method of the current implementation.
func Syr2(alpha float64, x, y Vector, a Symmetric) {
blas64.Dsyr2(a.Uplo, a.N, alpha, x.Data, x.Inc, y.Data, y.Inc, a.Data, a.Stride)
}
// Spr2 performs a rank-2 update
//  A += alpha * x * y^T + alpha * y * x^T,
// where A is an n×n symmetric matrix in packed format, x and y are vectors,
// and alpha is a scalar.
//
// n is taken from a.N. Spr2 forwards to the Dspr2 method of the current
// implementation.
func Spr2(alpha float64, x, y Vector, a SymmetricPacked) {
blas64.Dspr2(a.Uplo, a.N, alpha, x.Data, x.Inc, y.Data, y.Inc, a.Data)
}
// Level 3
// Gemm computes
//  C = alpha * A * B + beta * C,
// where A, B, and C are dense matrices, and alpha and beta are scalars.
// tA and tB specify whether A or B are transposed.
//
// m and k are read from a and n is read from b, each after accounting for
// the requested transposition. The dimensions of c are not consulted by this
// wrapper; only c.Data and c.Stride are forwarded to Dgemm.
func Gemm(tA, tB blas.Transpose, alpha float64, a, b General, beta float64, c General) {
	// Start from the untransposed shapes and swap when a transpose is requested.
	m, k := a.Rows, a.Cols
	if tA != blas.NoTrans {
		m, k = k, m
	}
	n := b.Cols
	if tB != blas.NoTrans {
		n = b.Rows
	}
	blas64.Dgemm(tA, tB, m, n, k, alpha, a.Data, a.Stride, b.Data, b.Stride, beta, c.Data, c.Stride)
}
// Symm performs
//  C = alpha * A * B + beta * C, if s == blas.Left,
//  C = alpha * B * A + beta * C, if s == blas.Right,
// where A is an n×n or m×m symmetric matrix, B and C are m×n matrices, and
// alpha is a scalar.
//
// With s == blas.Left, m is a.N and n is b.Cols; otherwise m is b.Rows and n
// is a.N. The dimensions of c are not consulted by this wrapper.
func Symm(s blas.Side, alpha float64, a Symmetric, b General, beta float64, c General) {
	// Default to the right-side shapes and override for a left-side multiply.
	m, n := b.Rows, a.N
	if s == blas.Left {
		m, n = a.N, b.Cols
	}
	blas64.Dsymm(s, a.Uplo, m, n, alpha, a.Data, a.Stride, b.Data, b.Stride, beta, c.Data, c.Stride)
}
// Syrk applies the symmetric rank-k update
//  C = alpha * A * A^T + beta * C, if t == blas.NoTrans,
//  C = alpha * A^T * A + beta * C, if t == blas.Trans or blas.ConjTrans,
// where C is an n×n symmetric matrix, A is an n×k matrix if t == blas.NoTrans
// and a k×n matrix otherwise, and alpha and beta are scalars.
func Syrk(t blas.Transpose, alpha float64, a General, beta float64, c Symmetric) {
	n, k := a.Rows, a.Cols
	if t != blas.NoTrans {
		n, k = k, n
	}
	blas64.Dsyrk(
		c.Uplo, t,
		n, k,
		alpha,
		a.Data, a.Stride,
		beta,
		c.Data, c.Stride,
	)
}
// Syr2k applies the symmetric rank-2k update
//  C = alpha * A * B^T + alpha * B * A^T + beta * C, if t == blas.NoTrans,
//  C = alpha * A^T * B + alpha * B^T * A + beta * C, if t == blas.Trans or blas.ConjTrans,
// where C is an n×n symmetric matrix, A and B are n×k matrices if
// t == blas.NoTrans and k×n matrices otherwise, and alpha and beta are scalars.
func Syr2k(t blas.Transpose, alpha float64, a, b General, beta float64, c Symmetric) {
	n, k := a.Rows, a.Cols
	if t != blas.NoTrans {
		n, k = k, n
	}
	blas64.Dsyr2k(
		c.Uplo, t,
		n, k,
		alpha,
		a.Data, a.Stride,
		b.Data, b.Stride,
		beta,
		c.Data, c.Stride,
	)
}
// Trmm computes
//  B = alpha * A * B,   if tA == blas.NoTrans and s == blas.Left,
//  B = alpha * A^T * B, if tA == blas.Trans or blas.ConjTrans, and s == blas.Left,
//  B = alpha * B * A,   if tA == blas.NoTrans and s == blas.Right,
//  B = alpha * B * A^T, if tA == blas.Trans or blas.ConjTrans, and s == blas.Right,
// where A is an n×n or m×m triangular matrix, B is an m×n matrix, and alpha is
// a scalar.
func Trmm(s blas.Side, tA blas.Transpose, alpha float64, a Triangular, b General) {
	m, n := b.Rows, b.Cols
	blas64.Dtrmm(
		s, a.Uplo, tA, a.Diag,
		m, n,
		alpha,
		a.Data, a.Stride,
		b.Data, b.Stride,
	)
}
// Trsm solves one of
//  A * X = alpha * B,   if tA == blas.NoTrans and s == blas.Left,
//  A^T * X = alpha * B, if tA == blas.Trans or blas.ConjTrans, and s == blas.Left,
//  X * A = alpha * B,   if tA == blas.NoTrans and s == blas.Right,
//  X * A^T = alpha * B, if tA == blas.Trans or blas.ConjTrans, and s == blas.Right,
// where A is an n×n or m×m triangular matrix, X and B are m×n matrices, and
// alpha is a scalar.
//
// On entry b holds the values of B; the solution X is written in-place
// into b.
//
// No check is made that A is invertible.
func Trsm(s blas.Side, tA blas.Transpose, alpha float64, a Triangular, b General) {
	m, n := b.Rows, b.Cols
	blas64.Dtrsm(
		s, a.Uplo, tA, a.Diag,
		m, n,
		alpha,
		a.Data, a.Stride,
		b.Data, b.Stride,
	)
}

510
blas/cblas128/cblas128.go Normal file
View File

@@ -0,0 +1,510 @@
// Copyright ©2015 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package cblas128 provides a simple interface to the complex128 BLAS API.
package cblas128
import (
"github.com/gonum/blas"
"github.com/gonum/blas/cgo"
)
// cblas128 is the BLAS complex128 implementation that the wrapper functions
// in this package dispatch to. It is replaced by Use and returned by
// Implementation.
//
// TODO(kortschak): Change this and the comment below to native.Implementation
// when blas/native covers the complex BLAS API.
var cblas128 blas.Complex128 = cgo.Implementation{}
// Use sets the BLAS complex128 implementation to be used by subsequent BLAS calls.
// The default implementation is cgo.Implementation.
//
// Use assigns a package-level variable without synchronization, so it should
// not be called concurrently with other functions in this package.
func Use(b blas.Complex128) {
	cblas128 = b
}
// Implementation returns the current BLAS complex128 implementation.
//
// Implementation allows direct calls to the current BLAS complex128
// implementation, giving finer control of parameters.
func Implementation() blas.Complex128 {
	return cblas128
}
// Vector represents a vector with an associated element increment.
type Vector struct {
	// Inc is the element increment that is passed to the BLAS routines
	// together with Data. Several functions in this package panic if it
	// is negative.
	Inc int
	// Data holds the vector elements.
	Data []complex128
}
// General represents a matrix using the conventional storage scheme.
type General struct {
	// Rows and Cols are the dimensions of the matrix.
	Rows, Cols int
	// Stride is passed to the BLAS routines together with Data.
	// NOTE(review): presumably the distance in Data between the starts of
	// successive rows; confirm against the blas package documentation.
	Stride int
	// Data holds the matrix elements.
	Data []complex128
}
// Band represents a band matrix using the band storage scheme.
type Band struct {
	// Rows and Cols are the dimensions of the matrix.
	Rows, Cols int
	// KL and KU are passed to the band BLAS routines after the dimensions.
	// NOTE(review): presumably the number of sub- and super-diagonals;
	// confirm against the blas package documentation.
	KL, KU int
	// Stride is passed to the BLAS routines together with Data.
	Stride int
	// Data holds the matrix elements in band storage.
	Data []complex128
}
// Triangular represents a triangular matrix using the conventional storage scheme.
type Triangular struct {
	// N is the order of the n×n matrix.
	N int
	// Stride is passed to the BLAS routines together with Data.
	Stride int
	// Data holds the matrix elements.
	Data []complex128
	// Uplo is forwarded to the BLAS routines.
	// NOTE(review): presumably selects the upper or lower triangle; confirm.
	Uplo blas.Uplo
	// Diag is forwarded to the BLAS routines.
	// NOTE(review): presumably marks a unit or non-unit diagonal; confirm.
	Diag blas.Diag
}
// TriangularBand represents a triangular matrix using the band storage scheme.
type TriangularBand struct {
	// N is the order of the n×n matrix. K is passed to the band BLAS
	// routines after N.
	// NOTE(review): K is presumably the number of off-diagonals; confirm.
	N, K int
	// Stride is passed to the BLAS routines together with Data.
	Stride int
	// Data holds the matrix elements in band storage.
	Data []complex128
	// Uplo is forwarded to the BLAS routines.
	// NOTE(review): presumably selects the upper or lower triangle; confirm.
	Uplo blas.Uplo
	// Diag is forwarded to the BLAS routines.
	// NOTE(review): presumably marks a unit or non-unit diagonal; confirm.
	Diag blas.Diag
}
// TriangularPacked represents a triangular matrix using the packed storage scheme.
type TriangularPacked struct {
	// N is the order of the n×n matrix.
	N int
	// Data holds the matrix elements in packed storage; no stride is
	// passed to the BLAS routines for packed matrices.
	Data []complex128
	// Uplo is forwarded to the BLAS routines.
	// NOTE(review): presumably selects the upper or lower triangle; confirm.
	Uplo blas.Uplo
	// Diag is forwarded to the BLAS routines.
	// NOTE(review): presumably marks a unit or non-unit diagonal; confirm.
	Diag blas.Diag
}
// Symmetric represents a symmetric matrix using the conventional storage scheme.
type Symmetric struct {
	// N is the order of the n×n matrix.
	N int
	// Stride is passed to the BLAS routines together with Data.
	Stride int
	// Data holds the matrix elements.
	Data []complex128
	// Uplo is forwarded to the BLAS routines.
	// NOTE(review): presumably selects which triangle is referenced; confirm.
	Uplo blas.Uplo
}
// SymmetricBand represents a symmetric matrix using the band storage scheme.
type SymmetricBand struct {
	// N is the order of the n×n matrix. K is passed to the band BLAS
	// routines after N.
	// NOTE(review): K is presumably the number of off-diagonals; confirm.
	N, K int
	// Stride is passed to the BLAS routines together with Data.
	Stride int
	// Data holds the matrix elements in band storage.
	Data []complex128
	// Uplo is forwarded to the BLAS routines.
	// NOTE(review): presumably selects which triangle is referenced; confirm.
	Uplo blas.Uplo
}
// SymmetricPacked represents a symmetric matrix using the packed storage scheme.
type SymmetricPacked struct {
	// N is the order of the n×n matrix.
	N int
	// Data holds the matrix elements in packed storage; no stride is
	// passed to the BLAS routines for packed matrices.
	Data []complex128
	// Uplo is forwarded to the BLAS routines.
	// NOTE(review): presumably selects which triangle is referenced; confirm.
	Uplo blas.Uplo
}
// Hermitian represents an Hermitian matrix using the conventional storage scheme.
// It has the same underlying type as Symmetric.
type Hermitian Symmetric

// HermitianBand represents an Hermitian matrix using the band storage scheme.
// It has the same underlying type as SymmetricBand.
type HermitianBand SymmetricBand

// HermitianPacked represents an Hermitian matrix using the packed storage scheme.
// It has the same underlying type as SymmetricPacked.
type HermitianPacked SymmetricPacked
// Level 1

// negInc is the panic message used by the functions in this package that
// reject a negative vector increment.
const negInc = "cblas128: negative vector increment"
// Dotu returns the dot product of x and y without complex conjugation:
//  x^T * y.
func Dotu(n int, x, y Vector) complex128 {
	return cblas128.Zdotu(
		n,
		x.Data, x.Inc,
		y.Data, y.Inc,
	)
}
// Dotc returns the dot product of x and y with complex conjugation of x:
//  x^H * y.
func Dotc(n int, x, y Vector) complex128 {
	return cblas128.Zdotc(
		n,
		x.Data, x.Inc,
		y.Data, y.Inc,
	)
}
// Nrm2 computes the Euclidean norm of the complex vector x:
//  sqrt(\sum_i |x[i]|^2) = sqrt(x^H * x).
//
// Nrm2 will panic if the vector increment is negative.
func Nrm2(n int, x Vector) float64 {
	if x.Inc < 0 {
		panic(negInc)
	}
	return cblas128.Dznrm2(n, x.Data, x.Inc)
}
// Asum returns the sum of the magnitudes of the real and imaginary parts of
// the elements of x:
//  \sum_i (|Re x[i]| + |Im x[i]|).
//
// Asum will panic if the vector increment is negative.
func Asum(n int, x Vector) float64 {
	data, inc := x.Data, x.Inc
	if inc < 0 {
		panic(negInc)
	}
	return cblas128.Dzasum(n, data, inc)
}
// Iamax returns the index of an element of x with the largest sum of
// magnitudes of the real and imaginary parts (|Re x[i]|+|Im x[i]|).
// When several elements tie, the earliest index is returned.
//
// Iamax returns -1 if n == 0.
//
// Iamax will panic if the vector increment is negative.
func Iamax(n int, x Vector) int {
	data, inc := x.Data, x.Inc
	if inc < 0 {
		panic(negInc)
	}
	return cblas128.Izamax(n, data, inc)
}
// Swap exchanges the elements of x and y:
//  x[i], y[i] = y[i], x[i] for all i.
func Swap(n int, x, y Vector) {
	cblas128.Zswap(
		n,
		x.Data, x.Inc,
		y.Data, y.Inc,
	)
}
// Copy copies the elements of x into the elements of y:
//  y[i] = x[i] for all i.
func Copy(n int, x, y Vector) {
	cblas128.Zcopy(
		n,
		x.Data, x.Inc,
		y.Data, y.Inc,
	)
}
// Axpy computes
//  y = alpha * x + y,
// where x and y are vectors, and alpha is a scalar.
func Axpy(n int, alpha complex128, x, y Vector) {
	cblas128.Zaxpy(
		n,
		alpha,
		x.Data, x.Inc,
		y.Data, y.Inc,
	)
}
// Scal scales x by the complex scalar alpha:
//  x = alpha * x.
//
// Scal will panic if the vector increment is negative.
func Scal(n int, alpha complex128, x Vector) {
	data, inc := x.Data, x.Inc
	if inc < 0 {
		panic(negInc)
	}
	cblas128.Zscal(n, alpha, data, inc)
}
// Dscal scales x by the real scalar alpha:
//  x = alpha * x.
//
// Dscal will panic if the vector increment is negative.
func Dscal(n int, alpha float64, x Vector) {
	data, inc := x.Data, x.Inc
	if inc < 0 {
		panic(negInc)
	}
	cblas128.Zdscal(n, alpha, data, inc)
}
// Level 2
// Gemv computes
//  y = alpha * A * x + beta * y,   if t == blas.NoTrans,
//  y = alpha * A^T * x + beta * y, if t == blas.Trans,
//  y = alpha * A^H * x + beta * y, if t == blas.ConjTrans,
// where A is an m×n dense matrix, x and y are vectors, and alpha and beta are
// scalars.
func Gemv(t blas.Transpose, alpha complex128, a General, x Vector, beta complex128, y Vector) {
	cblas128.Zgemv(
		t,
		a.Rows, a.Cols,
		alpha,
		a.Data, a.Stride,
		x.Data, x.Inc,
		beta,
		y.Data, y.Inc,
	)
}
// Gbmv computes
//  y = alpha * A * x + beta * y,   if t == blas.NoTrans,
//  y = alpha * A^T * x + beta * y, if t == blas.Trans,
//  y = alpha * A^H * x + beta * y, if t == blas.ConjTrans,
// where A is an m×n band matrix, x and y are vectors, and alpha and beta are
// scalars.
func Gbmv(t blas.Transpose, alpha complex128, a Band, x Vector, beta complex128, y Vector) {
	cblas128.Zgbmv(
		t,
		a.Rows, a.Cols,
		a.KL, a.KU,
		alpha,
		a.Data, a.Stride,
		x.Data, x.Inc,
		beta,
		y.Data, y.Inc,
	)
}
// Trmv computes
//  x = A * x,   if t == blas.NoTrans,
//  x = A^T * x, if t == blas.Trans,
//  x = A^H * x, if t == blas.ConjTrans,
// where A is an n×n triangular matrix, and x is a vector.
func Trmv(t blas.Transpose, a Triangular, x Vector) {
	cblas128.Ztrmv(
		a.Uplo, t, a.Diag,
		a.N,
		a.Data, a.Stride,
		x.Data, x.Inc,
	)
}
// Tbmv computes
//  x = A * x,   if t == blas.NoTrans,
//  x = A^T * x, if t == blas.Trans,
//  x = A^H * x, if t == blas.ConjTrans,
// where A is an n×n triangular band matrix, and x is a vector.
func Tbmv(t blas.Transpose, a TriangularBand, x Vector) {
	cblas128.Ztbmv(
		a.Uplo, t, a.Diag,
		a.N, a.K,
		a.Data, a.Stride,
		x.Data, x.Inc,
	)
}
// Tpmv computes
//  x = A * x,   if t == blas.NoTrans,
//  x = A^T * x, if t == blas.Trans,
//  x = A^H * x, if t == blas.ConjTrans,
// where A is an n×n triangular matrix in packed format, and x is a vector.
func Tpmv(t blas.Transpose, a TriangularPacked, x Vector) {
	cblas128.Ztpmv(
		a.Uplo, t, a.Diag,
		a.N,
		a.Data,
		x.Data, x.Inc,
	)
}
// Trsv solves
//  A * x = b,   if t == blas.NoTrans,
//  A^T * x = b, if t == blas.Trans,
//  A^H * x = b, if t == blas.ConjTrans,
// where A is an n×n triangular matrix and x is a vector.
//
// On entry x holds the values of b; the solution overwrites x in-place.
//
// This routine does not test for singularity or near-singularity. Such
// tests must be performed before calling it.
func Trsv(t blas.Transpose, a Triangular, x Vector) {
	cblas128.Ztrsv(
		a.Uplo, t, a.Diag,
		a.N,
		a.Data, a.Stride,
		x.Data, x.Inc,
	)
}
// Tbsv solves
//  A * x = b,   if t == blas.NoTrans,
//  A^T * x = b, if t == blas.Trans,
//  A^H * x = b, if t == blas.ConjTrans,
// where A is an n×n triangular band matrix, and x is a vector.
//
// On entry x holds the values of b; the solution overwrites x in-place.
//
// This routine does not test for singularity or near-singularity. Such
// tests must be performed before calling it.
func Tbsv(t blas.Transpose, a TriangularBand, x Vector) {
	cblas128.Ztbsv(
		a.Uplo, t, a.Diag,
		a.N, a.K,
		a.Data, a.Stride,
		x.Data, x.Inc,
	)
}
// Tpsv solves
//  A * x = b,   if t == blas.NoTrans,
//  A^T * x = b, if t == blas.Trans,
//  A^H * x = b, if t == blas.ConjTrans,
// where A is an n×n triangular matrix in packed format and x is a vector.
//
// On entry x holds the values of b; the solution overwrites x in-place.
//
// This routine does not test for singularity or near-singularity. Such
// tests must be performed before calling it.
func Tpsv(t blas.Transpose, a TriangularPacked, x Vector) {
	cblas128.Ztpsv(
		a.Uplo, t, a.Diag,
		a.N,
		a.Data,
		x.Data, x.Inc,
	)
}
// Hemv computes
//  y = alpha * A * x + beta * y,
// where A is an n×n Hermitian matrix, x and y are vectors, and alpha and
// beta are scalars.
func Hemv(alpha complex128, a Hermitian, x Vector, beta complex128, y Vector) {
	cblas128.Zhemv(
		a.Uplo, a.N,
		alpha,
		a.Data, a.Stride,
		x.Data, x.Inc,
		beta,
		y.Data, y.Inc,
	)
}
// Hbmv computes
//  y = alpha * A * x + beta * y,
// where A is an n×n Hermitian band matrix, x and y are vectors, and alpha
// and beta are scalars.
func Hbmv(alpha complex128, a HermitianBand, x Vector, beta complex128, y Vector) {
	cblas128.Zhbmv(
		a.Uplo, a.N, a.K,
		alpha,
		a.Data, a.Stride,
		x.Data, x.Inc,
		beta,
		y.Data, y.Inc,
	)
}
// Hpmv computes
//  y = alpha * A * x + beta * y,
// where A is an n×n Hermitian matrix in packed format, x and y are vectors,
// and alpha and beta are scalars.
func Hpmv(alpha complex128, a HermitianPacked, x Vector, beta complex128, y Vector) {
	cblas128.Zhpmv(
		a.Uplo, a.N,
		alpha,
		a.Data,
		x.Data, x.Inc,
		beta,
		y.Data, y.Inc,
	)
}
// Geru applies the unconjugated rank-1 update
//  A += alpha * x * y^T,
// to the m×n dense matrix A, where x and y are vectors and alpha is a scalar.
func Geru(alpha complex128, x, y Vector, a General) {
	cblas128.Zgeru(
		a.Rows, a.Cols,
		alpha,
		x.Data, x.Inc,
		y.Data, y.Inc,
		a.Data, a.Stride,
	)
}
// Gerc applies the conjugated rank-1 update
//  A += alpha * x * y^H,
// to the m×n dense matrix A, where x and y are vectors and alpha is a scalar.
func Gerc(alpha complex128, x, y Vector, a General) {
	cblas128.Zgerc(
		a.Rows, a.Cols,
		alpha,
		x.Data, x.Inc,
		y.Data, y.Inc,
		a.Data, a.Stride,
	)
}
// Her performs a rank-1 update
//  A += alpha * x * x^H,
// where A is an n×n Hermitian matrix, x is a vector, and alpha is a real
// scalar.
func Her(alpha float64, x Vector, a Hermitian) {
	cblas128.Zher(a.Uplo, a.N, alpha, x.Data, x.Inc, a.Data, a.Stride)
}
// Hpr applies the Hermitian rank-1 update
//  A += alpha * x * x^H,
// to the n×n Hermitian matrix A held in packed format, where x is a vector
// and alpha is a scalar.
func Hpr(alpha float64, x Vector, a HermitianPacked) {
	cblas128.Zhpr(
		a.Uplo, a.N,
		alpha,
		x.Data, x.Inc,
		a.Data,
	)
}
// Her2 applies the Hermitian rank-2 update
//  A += alpha * x * y^H + conj(alpha) * y * x^H,
// to the n×n Hermitian matrix A, where x and y are vectors and alpha is a
// scalar.
func Her2(alpha complex128, x, y Vector, a Hermitian) {
	cblas128.Zher2(
		a.Uplo, a.N,
		alpha,
		x.Data, x.Inc,
		y.Data, y.Inc,
		a.Data, a.Stride,
	)
}
// Hpr2 applies the Hermitian rank-2 update
//  A += alpha * x * y^H + conj(alpha) * y * x^H,
// to the n×n Hermitian matrix A held in packed format, where x and y are
// vectors and alpha is a scalar.
func Hpr2(alpha complex128, x, y Vector, a HermitianPacked) {
	cblas128.Zhpr2(
		a.Uplo, a.N,
		alpha,
		x.Data, x.Inc,
		y.Data, y.Inc,
		a.Data,
	)
}
// Level 3
// Gemm computes
//  C = alpha * A * B + beta * C,
// where A, B, and C are dense matrices, and alpha and beta are scalars.
// tA and tB specify whether A or B are transposed or conjugated.
func Gemm(tA, tB blas.Transpose, alpha complex128, a, b General, beta complex128, c General) {
	// Start from the untransposed dimensions and swap when A is transposed.
	m, k := a.Rows, a.Cols
	if tA != blas.NoTrans {
		m, k = k, m
	}
	n := b.Cols
	if tB != blas.NoTrans {
		n = b.Rows
	}
	cblas128.Zgemm(
		tA, tB,
		m, n, k,
		alpha,
		a.Data, a.Stride,
		b.Data, b.Stride,
		beta,
		c.Data, c.Stride,
	)
}
// Symm computes
//  C = alpha * A * B + beta * C, if s == blas.Left,
//  C = alpha * B * A + beta * C, if s == blas.Right,
// where A is an n×n or m×m symmetric matrix, B and C are m×n matrices, and
// alpha and beta are scalars.
func Symm(s blas.Side, alpha complex128, a Symmetric, b General, beta complex128, c General) {
	// Default to the right-side dimensions; override when A is on the left.
	m, n := b.Rows, a.N
	if s == blas.Left {
		m, n = a.N, b.Cols
	}
	cblas128.Zsymm(
		s, a.Uplo,
		m, n,
		alpha,
		a.Data, a.Stride,
		b.Data, b.Stride,
		beta,
		c.Data, c.Stride,
	)
}
// Syrk applies the symmetric rank-k update
//  C = alpha * A * A^T + beta * C, if t == blas.NoTrans,
//  C = alpha * A^T * A + beta * C, if t == blas.Trans,
// where C is an n×n symmetric matrix, A is an n×k matrix if t == blas.NoTrans
// and a k×n matrix otherwise, and alpha and beta are scalars.
func Syrk(t blas.Transpose, alpha complex128, a General, beta complex128, c Symmetric) {
	n, k := a.Rows, a.Cols
	if t != blas.NoTrans {
		n, k = k, n
	}
	cblas128.Zsyrk(
		c.Uplo, t,
		n, k,
		alpha,
		a.Data, a.Stride,
		beta,
		c.Data, c.Stride,
	)
}
// Syr2k applies the symmetric rank-2k update
//  C = alpha * A * B^T + alpha * B * A^T + beta * C, if t == blas.NoTrans,
//  C = alpha * A^T * B + alpha * B^T * A + beta * C, if t == blas.Trans,
// where C is an n×n symmetric matrix, A and B are n×k matrices if
// t == blas.NoTrans and k×n otherwise, and alpha and beta are scalars.
func Syr2k(t blas.Transpose, alpha complex128, a, b General, beta complex128, c Symmetric) {
	n, k := a.Rows, a.Cols
	if t != blas.NoTrans {
		n, k = k, n
	}
	cblas128.Zsyr2k(
		c.Uplo, t,
		n, k,
		alpha,
		a.Data, a.Stride,
		b.Data, b.Stride,
		beta,
		c.Data, c.Stride,
	)
}
// Trmm computes
//  B = alpha * A * B,   if tA == blas.NoTrans and s == blas.Left,
//  B = alpha * A^T * B, if tA == blas.Trans and s == blas.Left,
//  B = alpha * A^H * B, if tA == blas.ConjTrans and s == blas.Left,
//  B = alpha * B * A,   if tA == blas.NoTrans and s == blas.Right,
//  B = alpha * B * A^T, if tA == blas.Trans and s == blas.Right,
//  B = alpha * B * A^H, if tA == blas.ConjTrans and s == blas.Right,
// where A is an n×n or m×m triangular matrix, B is an m×n matrix, and alpha is
// a scalar.
func Trmm(s blas.Side, tA blas.Transpose, alpha complex128, a Triangular, b General) {
	m, n := b.Rows, b.Cols
	cblas128.Ztrmm(
		s, a.Uplo, tA, a.Diag,
		m, n,
		alpha,
		a.Data, a.Stride,
		b.Data, b.Stride,
	)
}
// Trsm solves one of
//  A * X = alpha * B,   if tA == blas.NoTrans and s == blas.Left,
//  A^T * X = alpha * B, if tA == blas.Trans and s == blas.Left,
//  A^H * X = alpha * B, if tA == blas.ConjTrans and s == blas.Left,
//  X * A = alpha * B,   if tA == blas.NoTrans and s == blas.Right,
//  X * A^T = alpha * B, if tA == blas.Trans and s == blas.Right,
//  X * A^H = alpha * B, if tA == blas.ConjTrans and s == blas.Right,
// where A is an n×n or m×m triangular matrix, X and B are m×n matrices, and
// alpha is a scalar.
//
// On entry b holds the values of B; the solution X is written in-place
// into b.
//
// No check is made that A is invertible.
func Trsm(s blas.Side, tA blas.Transpose, alpha complex128, a Triangular, b General) {
	m, n := b.Rows, b.Cols
	cblas128.Ztrsm(
		s, a.Uplo, tA, a.Diag,
		m, n,
		alpha,
		a.Data, a.Stride,
		b.Data, b.Stride,
	)
}
// Hemm computes
//  C = alpha * A * B + beta * C, if s == blas.Left,
//  C = alpha * B * A + beta * C, if s == blas.Right,
// where A is an n×n or m×m Hermitian matrix, B and C are m×n matrices, and
// alpha and beta are scalars.
func Hemm(s blas.Side, alpha complex128, a Hermitian, b General, beta complex128, c General) {
	// Default to the right-side dimensions; override when A is on the left.
	m, n := b.Rows, a.N
	if s == blas.Left {
		m, n = a.N, b.Cols
	}
	cblas128.Zhemm(
		s, a.Uplo,
		m, n,
		alpha,
		a.Data, a.Stride,
		b.Data, b.Stride,
		beta,
		c.Data, c.Stride,
	)
}
// Herk applies the Hermitian rank-k update
//  C = alpha * A * A^H + beta*C, if t == blas.NoTrans,
//  C = alpha * A^H * A + beta*C, if t == blas.ConjTrans,
// where C is an n×n Hermitian matrix, A is an n×k matrix if t == blas.NoTrans
// and a k×n matrix otherwise, and alpha and beta are scalars.
func Herk(t blas.Transpose, alpha float64, a General, beta float64, c Hermitian) {
	n, k := a.Rows, a.Cols
	if t != blas.NoTrans {
		n, k = k, n
	}
	cblas128.Zherk(
		c.Uplo, t,
		n, k,
		alpha,
		a.Data, a.Stride,
		beta,
		c.Data, c.Stride,
	)
}
// Her2k applies the Hermitian rank-2k update
//  C = alpha * A * B^H + conj(alpha) * B * A^H + beta * C, if t == blas.NoTrans,
//  C = alpha * A^H * B + conj(alpha) * B^H * A + beta * C, if t == blas.ConjTrans,
// where C is an n×n Hermitian matrix, A and B are n×k matrices if
// t == blas.NoTrans and k×n matrices otherwise, and alpha and beta are scalars.
func Her2k(t blas.Transpose, alpha complex128, a, b General, beta float64, c Hermitian) {
	n, k := a.Rows, a.Cols
	if t != blas.NoTrans {
		n, k = k, n
	}
	cblas128.Zher2k(
		c.Uplo, t,
		n, k,
		alpha,
		a.Data, a.Stride,
		b.Data, b.Stride,
		beta,
		c.Data, c.Stride,
	)
}

510
blas/cblas64/cblas64.go Normal file
View File

@@ -0,0 +1,510 @@
// Copyright ©2015 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package cblas64 provides a simple interface to the complex64 BLAS API.
package cblas64
import (
"github.com/gonum/blas"
"github.com/gonum/blas/cgo"
)
// cblas64 is the BLAS complex64 implementation that the wrapper functions
// in this package dispatch to. It is replaced by Use and returned by
// Implementation.
//
// TODO(kortschak): Change this and the comment below to native.Implementation
// when blas/native covers the complex BLAS API.
var cblas64 blas.Complex64 = cgo.Implementation{}
// Use sets the BLAS complex64 implementation to be used by subsequent BLAS calls.
// The default implementation is cgo.Implementation.
//
// Use assigns a package-level variable without synchronization, so it should
// not be called concurrently with other functions in this package.
func Use(b blas.Complex64) {
	cblas64 = b
}
// Implementation returns the current BLAS complex64 implementation.
//
// Implementation allows direct calls to the current BLAS complex64
// implementation, giving finer control of parameters.
func Implementation() blas.Complex64 {
	return cblas64
}
// Vector represents a vector with an associated element increment.
type Vector struct {
	// Inc is the element increment that is passed to the BLAS routines
	// together with Data. Several functions in this package panic if it
	// is negative.
	Inc int
	// Data holds the vector elements.
	Data []complex64
}
// General represents a matrix using the conventional storage scheme.
type General struct {
	// Rows and Cols are the dimensions of the matrix.
	Rows, Cols int
	// Stride is passed to the BLAS routines together with Data.
	// NOTE(review): presumably the distance in Data between the starts of
	// successive rows; confirm against the blas package documentation.
	Stride int
	// Data holds the matrix elements.
	Data []complex64
}
// Band represents a band matrix using the band storage scheme.
type Band struct {
	// Rows and Cols are the dimensions of the matrix.
	Rows, Cols int
	// KL and KU are passed to the band BLAS routines after the dimensions.
	// NOTE(review): presumably the number of sub- and super-diagonals;
	// confirm against the blas package documentation.
	KL, KU int
	// Stride is passed to the BLAS routines together with Data.
	Stride int
	// Data holds the matrix elements in band storage.
	Data []complex64
}
// Triangular represents a triangular matrix using the conventional storage scheme.
type Triangular struct {
	// N is the order of the n×n matrix.
	N int
	// Stride is passed to the BLAS routines together with Data.
	Stride int
	// Data holds the matrix elements.
	Data []complex64
	// Uplo is forwarded to the BLAS routines.
	// NOTE(review): presumably selects the upper or lower triangle; confirm.
	Uplo blas.Uplo
	// Diag is forwarded to the BLAS routines.
	// NOTE(review): presumably marks a unit or non-unit diagonal; confirm.
	Diag blas.Diag
}
// TriangularBand represents a triangular matrix using the band storage scheme.
type TriangularBand struct {
	// N is the order of the n×n matrix. K is passed to the band BLAS
	// routines after N.
	// NOTE(review): K is presumably the number of off-diagonals; confirm.
	N, K int
	// Stride is passed to the BLAS routines together with Data.
	Stride int
	// Data holds the matrix elements in band storage.
	Data []complex64
	// Uplo is forwarded to the BLAS routines.
	// NOTE(review): presumably selects the upper or lower triangle; confirm.
	Uplo blas.Uplo
	// Diag is forwarded to the BLAS routines.
	// NOTE(review): presumably marks a unit or non-unit diagonal; confirm.
	Diag blas.Diag
}
// TriangularPacked represents a triangular matrix using the packed storage scheme.
type TriangularPacked struct {
	// N is the order of the n×n matrix.
	N int
	// Data holds the matrix elements in packed storage; no stride is
	// passed to the BLAS routines for packed matrices.
	Data []complex64
	// Uplo is forwarded to the BLAS routines.
	// NOTE(review): presumably selects the upper or lower triangle; confirm.
	Uplo blas.Uplo
	// Diag is forwarded to the BLAS routines.
	// NOTE(review): presumably marks a unit or non-unit diagonal; confirm.
	Diag blas.Diag
}
// Symmetric represents a symmetric matrix using the conventional storage scheme.
type Symmetric struct {
	// N is the order of the n×n matrix.
	N int
	// Stride is passed to the BLAS routines together with Data.
	Stride int
	// Data holds the matrix elements.
	Data []complex64
	// Uplo is forwarded to the BLAS routines.
	// NOTE(review): presumably selects which triangle is referenced; confirm.
	Uplo blas.Uplo
}
// SymmetricBand represents a symmetric matrix using the band storage scheme.
type SymmetricBand struct {
	// N is the order of the n×n matrix. K is passed to the band BLAS
	// routines after N.
	// NOTE(review): K is presumably the number of off-diagonals; confirm.
	N, K int
	// Stride is passed to the BLAS routines together with Data.
	Stride int
	// Data holds the matrix elements in band storage.
	Data []complex64
	// Uplo is forwarded to the BLAS routines.
	// NOTE(review): presumably selects which triangle is referenced; confirm.
	Uplo blas.Uplo
}
// SymmetricPacked represents a symmetric matrix using the packed storage scheme.
type SymmetricPacked struct {
	// N is the order of the n×n matrix.
	N int
	// Data holds the matrix elements in packed storage; no stride is
	// passed to the BLAS routines for packed matrices.
	Data []complex64
	// Uplo is forwarded to the BLAS routines.
	// NOTE(review): presumably selects which triangle is referenced; confirm.
	Uplo blas.Uplo
}
// Hermitian represents an Hermitian matrix using the conventional storage scheme.
// It has the same underlying type as Symmetric.
type Hermitian Symmetric

// HermitianBand represents an Hermitian matrix using the band storage scheme.
// It has the same underlying type as SymmetricBand.
type HermitianBand SymmetricBand

// HermitianPacked represents an Hermitian matrix using the packed storage scheme.
// It has the same underlying type as SymmetricPacked.
type HermitianPacked SymmetricPacked
// Level 1

// negInc is the panic message used by the functions in this package that
// reject a negative vector increment.
const negInc = "cblas64: negative vector increment"
// Dotu returns the dot product of x and y without complex conjugation:
//  x^T * y.
func Dotu(n int, x, y Vector) complex64 {
	return cblas64.Cdotu(
		n,
		x.Data, x.Inc,
		y.Data, y.Inc,
	)
}
// Dotc returns the dot product of x and y with complex conjugation of x:
//  x^H * y.
func Dotc(n int, x, y Vector) complex64 {
	return cblas64.Cdotc(
		n,
		x.Data, x.Inc,
		y.Data, y.Inc,
	)
}
// Nrm2 computes the Euclidean norm of the complex vector x:
//  sqrt(\sum_i |x[i]|^2) = sqrt(x^H * x).
//
// Nrm2 will panic if the vector increment is negative.
func Nrm2(n int, x Vector) float32 {
	if x.Inc < 0 {
		panic(negInc)
	}
	return cblas64.Scnrm2(n, x.Data, x.Inc)
}
// Asum returns the sum of the magnitudes of the real and imaginary parts of
// the elements of x:
//  \sum_i (|Re x[i]| + |Im x[i]|).
//
// Asum will panic if the vector increment is negative.
func Asum(n int, x Vector) float32 {
	data, inc := x.Data, x.Inc
	if inc < 0 {
		panic(negInc)
	}
	return cblas64.Scasum(n, data, inc)
}
// Iamax returns the index of an element of x with the largest sum of
// magnitudes of the real and imaginary parts (|Re x[i]|+|Im x[i]|).
// When several elements tie, the earliest index is returned.
//
// Iamax returns -1 if n == 0.
//
// Iamax will panic if the vector increment is negative.
func Iamax(n int, x Vector) int {
	data, inc := x.Data, x.Inc
	if inc < 0 {
		panic(negInc)
	}
	return cblas64.Icamax(n, data, inc)
}
// Swap exchanges the elements of x and y:
//  x[i], y[i] = y[i], x[i] for all i.
func Swap(n int, x, y Vector) {
	cblas64.Cswap(
		n,
		x.Data, x.Inc,
		y.Data, y.Inc,
	)
}
// Copy copies the elements of x into the elements of y:
//  y[i] = x[i] for all i.
func Copy(n int, x, y Vector) {
	cblas64.Ccopy(
		n,
		x.Data, x.Inc,
		y.Data, y.Inc,
	)
}
// Axpy computes
//  y = alpha * x + y,
// where x and y are vectors, and alpha is a scalar.
func Axpy(n int, alpha complex64, x, y Vector) {
	cblas64.Caxpy(
		n,
		alpha,
		x.Data, x.Inc,
		y.Data, y.Inc,
	)
}
// Scal scales x by the complex scalar alpha:
//  x = alpha * x.
//
// Scal will panic if the vector increment is negative.
func Scal(n int, alpha complex64, x Vector) {
	data, inc := x.Data, x.Inc
	if inc < 0 {
		panic(negInc)
	}
	cblas64.Cscal(n, alpha, data, inc)
}
// Dscal scales x by the real scalar alpha:
//  x = alpha * x.
//
// Dscal will panic if the vector increment is negative.
func Dscal(n int, alpha float32, x Vector) {
	data, inc := x.Data, x.Inc
	if inc < 0 {
		panic(negInc)
	}
	cblas64.Csscal(n, alpha, data, inc)
}
// Level 2
// Gemv computes
//  y = alpha * A * x + beta * y,   if t == blas.NoTrans,
//  y = alpha * A^T * x + beta * y, if t == blas.Trans,
//  y = alpha * A^H * x + beta * y, if t == blas.ConjTrans,
// where A is an m×n dense matrix, x and y are vectors, and alpha and beta are
// scalars.
func Gemv(t blas.Transpose, alpha complex64, a General, x Vector, beta complex64, y Vector) {
	cblas64.Cgemv(
		t,
		a.Rows, a.Cols,
		alpha,
		a.Data, a.Stride,
		x.Data, x.Inc,
		beta,
		y.Data, y.Inc,
	)
}
// Gbmv computes
//  y = alpha * A * x + beta * y,   if t == blas.NoTrans,
//  y = alpha * A^T * x + beta * y, if t == blas.Trans,
//  y = alpha * A^H * x + beta * y, if t == blas.ConjTrans,
// where A is an m×n band matrix, x and y are vectors, and alpha and beta are
// scalars.
func Gbmv(t blas.Transpose, alpha complex64, a Band, x Vector, beta complex64, y Vector) {
	cblas64.Cgbmv(
		t,
		a.Rows, a.Cols,
		a.KL, a.KU,
		alpha,
		a.Data, a.Stride,
		x.Data, x.Inc,
		beta,
		y.Data, y.Inc,
	)
}
// Trmv computes
//  x = A * x,   if t == blas.NoTrans,
//  x = A^T * x, if t == blas.Trans,
//  x = A^H * x, if t == blas.ConjTrans,
// where A is an n×n triangular matrix, and x is a vector.
func Trmv(t blas.Transpose, a Triangular, x Vector) {
	cblas64.Ctrmv(
		a.Uplo, t, a.Diag,
		a.N,
		a.Data, a.Stride,
		x.Data, x.Inc,
	)
}
// Tbmv computes
//  x = A * x,   if t == blas.NoTrans,
//  x = A^T * x, if t == blas.Trans,
//  x = A^H * x, if t == blas.ConjTrans,
// where A is an n×n triangular band matrix, and x is a vector.
func Tbmv(t blas.Transpose, a TriangularBand, x Vector) {
	cblas64.Ctbmv(
		a.Uplo, t, a.Diag,
		a.N, a.K,
		a.Data, a.Stride,
		x.Data, x.Inc,
	)
}
// Tpmv computes
//  x = A * x,   if t == blas.NoTrans,
//  x = A^T * x, if t == blas.Trans,
//  x = A^H * x, if t == blas.ConjTrans,
// where A is an n×n triangular matrix in packed format, and x is a vector.
func Tpmv(t blas.Transpose, a TriangularPacked, x Vector) {
	cblas64.Ctpmv(
		a.Uplo, t, a.Diag,
		a.N,
		a.Data,
		x.Data, x.Inc,
	)
}
// Trsv solves
//  A * x = b,   if t == blas.NoTrans,
//  A^T * x = b, if t == blas.Trans,
//  A^H * x = b, if t == blas.ConjTrans,
// where A is an n×n triangular matrix and x is a vector.
//
// On entry x holds the values of b; the solution overwrites x in-place.
//
// This routine does not test for singularity or near-singularity. Such
// tests must be performed before calling it.
func Trsv(t blas.Transpose, a Triangular, x Vector) {
	cblas64.Ctrsv(
		a.Uplo, t, a.Diag,
		a.N,
		a.Data, a.Stride,
		x.Data, x.Inc,
	)
}
// Tbsv solves
//  A * x = b,   if t == blas.NoTrans,
//  A^T * x = b, if t == blas.Trans,
//  A^H * x = b, if t == blas.ConjTrans,
// where A is an n×n triangular band matrix, and x is a vector.
//
// On entry x holds the values of b; the solution overwrites x in-place.
//
// This routine does not test for singularity or near-singularity. Such
// tests must be performed before calling it.
func Tbsv(t blas.Transpose, a TriangularBand, x Vector) {
	cblas64.Ctbsv(
		a.Uplo, t, a.Diag,
		a.N, a.K,
		a.Data, a.Stride,
		x.Data, x.Inc,
	)
}
// Tpsv solves
//  A * x = b,   if t == blas.NoTrans,
//  A^T * x = b, if t == blas.Trans,
//  A^H * x = b, if t == blas.ConjTrans,
// where A is an n×n triangular matrix in packed format and x is a vector.
//
// On entry x holds the values of b; the solution overwrites x in-place.
//
// This routine does not test for singularity or near-singularity. Such
// tests must be performed before calling it.
func Tpsv(t blas.Transpose, a TriangularPacked, x Vector) {
	cblas64.Ctpsv(
		a.Uplo, t, a.Diag,
		a.N,
		a.Data,
		x.Data, x.Inc,
	)
}
// Hemv computes
//  y = alpha * A * x + beta * y,
// where A is an n×n Hermitian matrix, x and y are vectors, and alpha and
// beta are scalars.
func Hemv(alpha complex64, a Hermitian, x Vector, beta complex64, y Vector) {
	cblas64.Chemv(
		a.Uplo, a.N,
		alpha,
		a.Data, a.Stride,
		x.Data, x.Inc,
		beta,
		y.Data, y.Inc,
	)
}
// Hbmv computes
//  y = alpha * A * x + beta * y,
// where A is an n×n Hermitian band matrix, x and y are vectors, and alpha
// and beta are scalars.
func Hbmv(alpha complex64, a HermitianBand, x Vector, beta complex64, y Vector) {
	cblas64.Chbmv(
		a.Uplo, a.N, a.K,
		alpha,
		a.Data, a.Stride,
		x.Data, x.Inc,
		beta,
		y.Data, y.Inc,
	)
}
// Hpmv computes
//  y = alpha * A * x + beta * y,
// where A is an n×n Hermitian matrix in packed format, x and y are vectors,
// and alpha and beta are scalars.
func Hpmv(alpha complex64, a HermitianPacked, x Vector, beta complex64, y Vector) {
	cblas64.Chpmv(
		a.Uplo, a.N,
		alpha,
		a.Data,
		x.Data, x.Inc,
		beta,
		y.Data, y.Inc,
	)
}
// Geru applies the unconjugated rank-1 update
//  A += alpha * x * y^T,
// to the m×n dense matrix A, where x and y are vectors and alpha is a scalar.
func Geru(alpha complex64, x, y Vector, a General) {
	cblas64.Cgeru(
		a.Rows, a.Cols,
		alpha,
		x.Data, x.Inc,
		y.Data, y.Inc,
		a.Data, a.Stride,
	)
}
// Gerc applies the conjugated rank-1 update
//  A += alpha * x * y^H,
// to the m×n dense matrix A, where x and y are vectors and alpha is a scalar.
func Gerc(alpha complex64, x, y Vector, a General) {
	cblas64.Cgerc(
		a.Rows, a.Cols,
		alpha,
		x.Data, x.Inc,
		y.Data, y.Inc,
		a.Data, a.Stride,
	)
}
// Her performs a rank-1 update
//  A += alpha * x * x^H,
// where A is an n×n Hermitian matrix, x is a vector, and alpha is a real
// scalar.
func Her(alpha float32, x Vector, a Hermitian) {
	cblas64.Cher(a.Uplo, a.N, alpha, x.Data, x.Inc, a.Data, a.Stride)
}
// Hpr applies the Hermitian rank-1 update
//  A += alpha * x * x^H,
// to the n×n Hermitian matrix A held in packed format, where x is a vector
// and alpha is a scalar.
func Hpr(alpha float32, x Vector, a HermitianPacked) {
	cblas64.Chpr(
		a.Uplo, a.N,
		alpha,
		x.Data, x.Inc,
		a.Data,
	)
}
// Her2 applies the Hermitian rank-2 update
//  A += alpha * x * y^H + conj(alpha) * y * x^H,
// to the n×n Hermitian matrix A, where x and y are vectors and alpha is a
// scalar.
func Her2(alpha complex64, x, y Vector, a Hermitian) {
	cblas64.Cher2(
		a.Uplo, a.N,
		alpha,
		x.Data, x.Inc,
		y.Data, y.Inc,
		a.Data, a.Stride,
	)
}
// Hpr2 applies the Hermitian rank-2 update
//  A += alpha * x * y^H + conj(alpha) * y * x^H,
// to the n×n Hermitian matrix A held in packed format, where x and y are
// vectors and alpha is a scalar.
func Hpr2(alpha complex64, x, y Vector, a HermitianPacked) {
	cblas64.Chpr2(
		a.Uplo, a.N,
		alpha,
		x.Data, x.Inc,
		y.Data, y.Inc,
		a.Data,
	)
}
// Level 3
// Gemm computes
//  C = alpha * A * B + beta * C,
// where A, B, and C are dense matrices, and alpha and beta are scalars.
// tA and tB specify whether A or B are transposed or conjugated.
func Gemm(tA, tB blas.Transpose, alpha complex64, a, b General, beta complex64, c General) {
	// Start from the untransposed dimensions and swap when A is transposed.
	m, k := a.Rows, a.Cols
	if tA != blas.NoTrans {
		m, k = k, m
	}
	n := b.Cols
	if tB != blas.NoTrans {
		n = b.Rows
	}
	cblas64.Cgemm(
		tA, tB,
		m, n, k,
		alpha,
		a.Data, a.Stride,
		b.Data, b.Stride,
		beta,
		c.Data, c.Stride,
	)
}
// Symm computes the symmetric matrix product
//  C = alpha * A * B + beta * C, if s == blas.Left,
//  C = alpha * B * A + beta * C, if s == blas.Right,
// where A is a symmetric matrix (m×m for blas.Left, n×n for blas.Right),
// B and C are m×n matrices, and alpha and beta are scalars.
func Symm(s blas.Side, alpha complex64, a Symmetric, b General, beta complex64, c General) {
	// Default to the right-hand case and override when A multiplies from the left.
	m, n := b.Rows, a.N
	if s == blas.Left {
		m, n = a.N, b.Cols
	}

	cblas64.Csymm(
		s, a.Uplo,
		m, n,
		alpha,
		a.Data, a.Stride,
		b.Data, b.Stride,
		beta,
		c.Data, c.Stride,
	)
}
// Syrk applies the symmetric rank-k update
//  C = alpha * A * A^T + beta * C, if t == blas.NoTrans,
//  C = alpha * A^T * A + beta * C, if t == blas.Trans,
// where C is an n×n symmetric matrix, A is n×k when t == blas.NoTrans and
// k×n otherwise, and alpha and beta are scalars.
func Syrk(t blas.Transpose, alpha complex64, a General, beta complex64, c Symmetric) {
	// A is n×k as stored; any other t means it is k×n.
	n, k := a.Rows, a.Cols
	if t != blas.NoTrans {
		n, k = k, n
	}

	cblas64.Csyrk(
		c.Uplo, t,
		n, k,
		alpha,
		a.Data, a.Stride,
		beta,
		c.Data, c.Stride,
	)
}
// Syr2k applies the symmetric rank-2k update
//  C = alpha * A * B^T + alpha * B * A^T + beta * C, if t == blas.NoTrans,
//  C = alpha * A^T * B + alpha * B^T * A + beta * C, if t == blas.Trans,
// where C is an n×n symmetric matrix, A and B are n×k when
// t == blas.NoTrans and k×n otherwise, and alpha and beta are scalars.
//
// n and k are taken from a alone; Syr2k itself does not compare them against
// the shape of b.
func Syr2k(t blas.Transpose, alpha complex64, a, b General, beta complex64, c Symmetric) {
	// A is n×k as stored; any other t means it is k×n.
	n, k := a.Rows, a.Cols
	if t != blas.NoTrans {
		n, k = k, n
	}

	cblas64.Csyr2k(
		c.Uplo, t,
		n, k,
		alpha,
		a.Data, a.Stride,
		b.Data, b.Stride,
		beta,
		c.Data, c.Stride,
	)
}
// Trmm computes the triangular matrix product
//  B = alpha * A * B,   if tA == blas.NoTrans and s == blas.Left,
//  B = alpha * A^T * B, if tA == blas.Trans and s == blas.Left,
//  B = alpha * A^H * B, if tA == blas.ConjTrans and s == blas.Left,
//  B = alpha * B * A,   if tA == blas.NoTrans and s == blas.Right,
//  B = alpha * B * A^T, if tA == blas.Trans and s == blas.Right,
//  B = alpha * B * A^H, if tA == blas.ConjTrans and s == blas.Right,
// where A is a triangular matrix (m×m for blas.Left, n×n for blas.Right),
// B is an m×n matrix, and alpha is a scalar.
func Trmm(s blas.Side, tA blas.Transpose, alpha complex64, a Triangular, b General) {
	cblas64.Ctrmm(
		s, a.Uplo, tA, a.Diag,
		b.Rows, b.Cols,
		alpha,
		a.Data, a.Stride,
		b.Data, b.Stride,
	)
}
// Trsm solves the triangular system
//  A * X = alpha * B,   if tA == blas.NoTrans and s == blas.Left,
//  A^T * X = alpha * B, if tA == blas.Trans and s == blas.Left,
//  A^H * X = alpha * B, if tA == blas.ConjTrans and s == blas.Left,
//  X * A = alpha * B,   if tA == blas.NoTrans and s == blas.Right,
//  X * A^T = alpha * B, if tA == blas.Trans and s == blas.Right,
//  X * A^H = alpha * B, if tA == blas.ConjTrans and s == blas.Right,
// where A is a triangular matrix (m×m for blas.Left, n×n for blas.Right),
// X and B are m×n matrices, and alpha is a scalar.
//
// On entry b holds B; on return the solution X is stored in-place in b.
//
// No check is made that A is invertible.
func Trsm(s blas.Side, tA blas.Transpose, alpha complex64, a Triangular, b General) {
	cblas64.Ctrsm(
		s, a.Uplo, tA, a.Diag,
		b.Rows, b.Cols,
		alpha,
		a.Data, a.Stride,
		b.Data, b.Stride,
	)
}
// Hemm computes the Hermitian matrix product
//  C = alpha * A * B + beta * C, if s == blas.Left,
//  C = alpha * B * A + beta * C, if s == blas.Right,
// where A is a Hermitian matrix (m×m for blas.Left, n×n for blas.Right),
// B and C are m×n matrices, and alpha and beta are scalars.
func Hemm(s blas.Side, alpha complex64, a Hermitian, b General, beta complex64, c General) {
	// Default to the right-hand case and override when A multiplies from the left.
	m, n := b.Rows, a.N
	if s == blas.Left {
		m, n = a.N, b.Cols
	}

	cblas64.Chemm(
		s, a.Uplo,
		m, n,
		alpha,
		a.Data, a.Stride,
		b.Data, b.Stride,
		beta,
		c.Data, c.Stride,
	)
}
// Herk applies the Hermitian rank-k update
//  C = alpha * A * A^H + beta*C, if t == blas.NoTrans,
//  C = alpha * A^H * A + beta*C, if t == blas.ConjTrans,
// where C is an n×n Hermitian matrix, A is n×k when t == blas.NoTrans and
// k×n otherwise, and alpha and beta are real scalars.
func Herk(t blas.Transpose, alpha float32, a General, beta float32, c Hermitian) {
	// A is n×k as stored; any other t means it is k×n.
	n, k := a.Rows, a.Cols
	if t != blas.NoTrans {
		n, k = k, n
	}

	cblas64.Cherk(
		c.Uplo, t,
		n, k,
		alpha,
		a.Data, a.Stride,
		beta,
		c.Data, c.Stride,
	)
}
// Her2k applies the Hermitian rank-2k update
//  C = alpha * A * B^H + conj(alpha) * B * A^H + beta * C, if t == blas.NoTrans,
//  C = alpha * A^H * B + conj(alpha) * B^H * A + beta * C, if t == blas.ConjTrans,
// where C is an n×n Hermitian matrix, A and B are n×k when
// t == blas.NoTrans and k×n otherwise, alpha is a complex scalar, and beta is
// a real scalar.
//
// n and k are taken from a alone; Her2k itself does not compare them against
// the shape of b.
func Her2k(t blas.Transpose, alpha complex64, a, b General, beta float32, c Hermitian) {
	// A is n×k as stored; any other t means it is k×n.
	n, k := a.Rows, a.Cols
	if t != blas.NoTrans {
		n, k = k, n
	}

	cblas64.Cher2k(
		c.Uplo, t,
		n, k,
		alpha,
		a.Data, a.Stride,
		b.Data, b.Stride,
		beta,
		c.Data, c.Stride,
	)
}

18
blas/cgo/bench_test.go Normal file
View File

@@ -0,0 +1,18 @@
package cgo
import (
"github.com/gonum/blas"
"github.com/gonum/blas/testblas"
)
// Short aliases for the matrix size constants exported by testblas, used as
// the size arguments of the benchmarks in this package.
const (
Sm = testblas.SmallMat
Med = testblas.MediumMat
Lg = testblas.LargeMat
Hg = testblas.HugeMat
)
// Short aliases for the blas transpose options, used as the transpose
// arguments of the benchmarks in this package.
const (
T = blas.Trans
NT = blas.NoTrans
)

4258
blas/cgo/blas.go Normal file

File diff suppressed because it is too large Load Diff

596
blas/cgo/cblas.h Normal file
View File

@@ -0,0 +1,596 @@
#ifndef CBLAS_H
/* Enum definitions carry their own guard so they are emitted only once. */
#ifndef CBLAS_ENUM_DEFINED_H
#define CBLAS_ENUM_DEFINED_H
/* Storage order of matrix arguments: row-major or column-major. */
enum CBLAS_ORDER {CblasRowMajor=101, CblasColMajor=102 };
/* Operation applied to a matrix operand: none, transpose, or conjugate
 * transpose. NOTE(review): AtlasConj is presumably an ATLAS-specific
 * conjugate-without-transpose extension - confirm before relying on it. */
enum CBLAS_TRANSPOSE {CblasNoTrans=111, CblasTrans=112, CblasConjTrans=113,
AtlasConj=114};
/* Which triangle (upper or lower) of a matrix argument is referenced. */
enum CBLAS_UPLO {CblasUpper=121, CblasLower=122};
/* Whether a triangular matrix is non-unit or unit diagonal. */
enum CBLAS_DIAG {CblasNonUnit=131, CblasUnit=132};
/* Side (left or right) on which a matrix operand is applied. */
enum CBLAS_SIDE {CblasLeft=141, CblasRight=142};
#endif
#ifndef CBLAS_ENUM_ONLY
#define CBLAS_H
#define CBLAS_INDEX int
int cblas_errprn(int ierr, int info, char *form, ...);
/*
* ===========================================================================
* Prototypes for level 1 BLAS functions (complex are recast as routines)
* ===========================================================================
*/
float cblas_sdsdot(const int N, const float alpha, const float *X,
const int incX, const float *Y, const int incY);
double cblas_dsdot(const int N, const float *X, const int incX, const float *Y,
const int incY);
float cblas_sdot(const int N, const float *X, const int incX,
const float *Y, const int incY);
double cblas_ddot(const int N, const double *X, const int incX,
const double *Y, const int incY);
/*
* Functions having prefixes Z and C only
*/
void cblas_cdotu_sub(const int N, const void *X, const int incX,
const void *Y, const int incY, void *dotu);
void cblas_cdotc_sub(const int N, const void *X, const int incX,
const void *Y, const int incY, void *dotc);
void cblas_zdotu_sub(const int N, const void *X, const int incX,
const void *Y, const int incY, void *dotu);
void cblas_zdotc_sub(const int N, const void *X, const int incX,
const void *Y, const int incY, void *dotc);
/*
* Functions having prefixes S D SC DZ
*/
float cblas_snrm2(const int N, const float *X, const int incX);
float cblas_sasum(const int N, const float *X, const int incX);
double cblas_dnrm2(const int N, const double *X, const int incX);
double cblas_dasum(const int N, const double *X, const int incX);
float cblas_scnrm2(const int N, const void *X, const int incX);
float cblas_scasum(const int N, const void *X, const int incX);
double cblas_dznrm2(const int N, const void *X, const int incX);
double cblas_dzasum(const int N, const void *X, const int incX);
/*
* Functions having standard 4 prefixes (S D C Z)
*/
CBLAS_INDEX cblas_isamax(const int N, const float *X, const int incX);
CBLAS_INDEX cblas_idamax(const int N, const double *X, const int incX);
CBLAS_INDEX cblas_icamax(const int N, const void *X, const int incX);
CBLAS_INDEX cblas_izamax(const int N, const void *X, const int incX);
/*
* ===========================================================================
* Prototypes for level 1 BLAS routines
* ===========================================================================
*/
/*
* Routines with standard 4 prefixes (s, d, c, z)
*/
void cblas_sswap(const int N, float *X, const int incX,
float *Y, const int incY);
void cblas_scopy(const int N, const float *X, const int incX,
float *Y, const int incY);
void cblas_saxpy(const int N, const float alpha, const float *X,
const int incX, float *Y, const int incY);
void catlas_saxpby(const int N, const float alpha, const float *X,
const int incX, const float beta, float *Y, const int incY);
void catlas_sset
(const int N, const float alpha, float *X, const int incX);
void cblas_dswap(const int N, double *X, const int incX,
double *Y, const int incY);
void cblas_dcopy(const int N, const double *X, const int incX,
double *Y, const int incY);
void cblas_daxpy(const int N, const double alpha, const double *X,
const int incX, double *Y, const int incY);
void catlas_daxpby(const int N, const double alpha, const double *X,
const int incX, const double beta, double *Y, const int incY);
void catlas_dset
(const int N, const double alpha, double *X, const int incX);
void cblas_cswap(const int N, void *X, const int incX,
void *Y, const int incY);
void cblas_ccopy(const int N, const void *X, const int incX,
void *Y, const int incY);
void cblas_caxpy(const int N, const void *alpha, const void *X,
const int incX, void *Y, const int incY);
void catlas_caxpby(const int N, const void *alpha, const void *X,
const int incX, const void *beta, void *Y, const int incY);
void catlas_cset
(const int N, const void *alpha, void *X, const int incX);
void cblas_zswap(const int N, void *X, const int incX,
void *Y, const int incY);
void cblas_zcopy(const int N, const void *X, const int incX,
void *Y, const int incY);
void cblas_zaxpy(const int N, const void *alpha, const void *X,
const int incX, void *Y, const int incY);
void catlas_zaxpby(const int N, const void *alpha, const void *X,
const int incX, const void *beta, void *Y, const int incY);
void catlas_zset
(const int N, const void *alpha, void *X, const int incX);
/*
* Routines with S and D prefix only
*/
void cblas_srotg(float *a, float *b, float *c, float *s);
void cblas_srotmg(float *d1, float *d2, float *b1, const float b2, float *P);
void cblas_srot(const int N, float *X, const int incX,
float *Y, const int incY, const float c, const float s);
void cblas_srotm(const int N, float *X, const int incX,
float *Y, const int incY, const float *P);
void cblas_drotg(double *a, double *b, double *c, double *s);
void cblas_drotmg(double *d1, double *d2, double *b1, const double b2, double *P);
void cblas_drot(const int N, double *X, const int incX,
double *Y, const int incY, const double c, const double s);
void cblas_drotm(const int N, double *X, const int incX,
double *Y, const int incY, const double *P);
/*
* Routines with S D C Z CS and ZD prefixes
*/
void cblas_sscal(const int N, const float alpha, float *X, const int incX);
void cblas_dscal(const int N, const double alpha, double *X, const int incX);
void cblas_cscal(const int N, const void *alpha, void *X, const int incX);
void cblas_zscal(const int N, const void *alpha, void *X, const int incX);
void cblas_csscal(const int N, const float alpha, void *X, const int incX);
void cblas_zdscal(const int N, const double alpha, void *X, const int incX);
/*
* Extra reference routines provided by ATLAS, but not mandated by the standard
*/
void cblas_crotg(void *a, void *b, void *c, void *s);
void cblas_zrotg(void *a, void *b, void *c, void *s);
void cblas_csrot(const int N, void *X, const int incX, void *Y, const int incY,
const float c, const float s);
void cblas_zdrot(const int N, void *X, const int incX, void *Y, const int incY,
const double c, const double s);
/*
* ===========================================================================
* Prototypes for level 2 BLAS
* ===========================================================================
*/
/*
* Routines with standard 4 prefixes (S, D, C, Z)
*/
void cblas_sgemv(const enum CBLAS_ORDER Order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const float alpha, const float *A, const int lda,
const float *X, const int incX, const float beta,
float *Y, const int incY);
void cblas_sgbmv(const enum CBLAS_ORDER Order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const int KL, const int KU, const float alpha,
const float *A, const int lda, const float *X,
const int incX, const float beta, float *Y, const int incY);
void cblas_strmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const float *A, const int lda,
float *X, const int incX);
void cblas_stbmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const float *A, const int lda,
float *X, const int incX);
void cblas_stpmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const float *Ap, float *X, const int incX);
void cblas_strsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const float *A, const int lda, float *X,
const int incX);
void cblas_stbsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const float *A, const int lda,
float *X, const int incX);
void cblas_stpsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const float *Ap, float *X, const int incX);
void cblas_dgemv(const enum CBLAS_ORDER Order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const double alpha, const double *A, const int lda,
const double *X, const int incX, const double beta,
double *Y, const int incY);
void cblas_dgbmv(const enum CBLAS_ORDER Order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const int KL, const int KU, const double alpha,
const double *A, const int lda, const double *X,
const int incX, const double beta, double *Y, const int incY);
void cblas_dtrmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const double *A, const int lda,
double *X, const int incX);
void cblas_dtbmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const double *A, const int lda,
double *X, const int incX);
void cblas_dtpmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const double *Ap, double *X, const int incX);
void cblas_dtrsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const double *A, const int lda, double *X,
const int incX);
void cblas_dtbsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const double *A, const int lda,
double *X, const int incX);
void cblas_dtpsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const double *Ap, double *X, const int incX);
void cblas_cgemv(const enum CBLAS_ORDER Order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const void *alpha, const void *A, const int lda,
const void *X, const int incX, const void *beta,
void *Y, const int incY);
void cblas_cgbmv(const enum CBLAS_ORDER Order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const int KL, const int KU, const void *alpha,
const void *A, const int lda, const void *X,
const int incX, const void *beta, void *Y, const int incY);
void cblas_ctrmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *A, const int lda,
void *X, const int incX);
void cblas_ctbmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const void *A, const int lda,
void *X, const int incX);
void cblas_ctpmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *Ap, void *X, const int incX);
void cblas_ctrsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *A, const int lda, void *X,
const int incX);
void cblas_ctbsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const void *A, const int lda,
void *X, const int incX);
void cblas_ctpsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *Ap, void *X, const int incX);
void cblas_zgemv(const enum CBLAS_ORDER Order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const void *alpha, const void *A, const int lda,
const void *X, const int incX, const void *beta,
void *Y, const int incY);
void cblas_zgbmv(const enum CBLAS_ORDER Order,
const enum CBLAS_TRANSPOSE TransA, const int M, const int N,
const int KL, const int KU, const void *alpha,
const void *A, const int lda, const void *X,
const int incX, const void *beta, void *Y, const int incY);
void cblas_ztrmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *A, const int lda,
void *X, const int incX);
void cblas_ztbmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const void *A, const int lda,
void *X, const int incX);
void cblas_ztpmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *Ap, void *X, const int incX);
void cblas_ztrsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *A, const int lda, void *X,
const int incX);
void cblas_ztbsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const int K, const void *A, const int lda,
void *X, const int incX);
void cblas_ztpsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag,
const int N, const void *Ap, void *X, const int incX);
/*
* Routines with S and D prefixes only
*/
void cblas_ssymv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const float *A,
const int lda, const float *X, const int incX,
const float beta, float *Y, const int incY);
void cblas_ssbmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const int N, const int K, const float alpha, const float *A,
const int lda, const float *X, const int incX,
const float beta, float *Y, const int incY);
void cblas_sspmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const float *Ap,
const float *X, const int incX,
const float beta, float *Y, const int incY);
void cblas_sger(const enum CBLAS_ORDER Order, const int M, const int N,
const float alpha, const float *X, const int incX,
const float *Y, const int incY, float *A, const int lda);
void cblas_ssyr(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const float *X,
const int incX, float *A, const int lda);
void cblas_sspr(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const float *X,
const int incX, float *Ap);
void cblas_ssyr2(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const float *X,
const int incX, const float *Y, const int incY, float *A,
const int lda);
void cblas_sspr2(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const float *X,
const int incX, const float *Y, const int incY, float *Ap);
void cblas_dsymv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const double *A,
const int lda, const double *X, const int incX,
const double beta, double *Y, const int incY);
void cblas_dsbmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const int N, const int K, const double alpha, const double *A,
const int lda, const double *X, const int incX,
const double beta, double *Y, const int incY);
void cblas_dspmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const double *Ap,
const double *X, const int incX,
const double beta, double *Y, const int incY);
void cblas_dger(const enum CBLAS_ORDER Order, const int M, const int N,
const double alpha, const double *X, const int incX,
const double *Y, const int incY, double *A, const int lda);
void cblas_dsyr(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const double *X,
const int incX, double *A, const int lda);
void cblas_dspr(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const double *X,
const int incX, double *Ap);
void cblas_dsyr2(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const double *X,
const int incX, const double *Y, const int incY, double *A,
const int lda);
void cblas_dspr2(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const double *X,
const int incX, const double *Y, const int incY, double *Ap);
/*
* Routines with C and Z prefixes only
*/
void cblas_chemv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const int N, const void *alpha, const void *A,
const int lda, const void *X, const int incX,
const void *beta, void *Y, const int incY);
void cblas_chbmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const int N, const int K, const void *alpha, const void *A,
const int lda, const void *X, const int incX,
const void *beta, void *Y, const int incY);
void cblas_chpmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const int N, const void *alpha, const void *Ap,
const void *X, const int incX,
const void *beta, void *Y, const int incY);
void cblas_cgeru(const enum CBLAS_ORDER Order, const int M, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *A, const int lda);
void cblas_cgerc(const enum CBLAS_ORDER Order, const int M, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *A, const int lda);
void cblas_cher(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const void *X, const int incX,
void *A, const int lda);
void cblas_chpr(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const int N, const float alpha, const void *X,
const int incX, void *Ap);
void cblas_cher2(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *A, const int lda);
void cblas_chpr2(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *Ap);
void cblas_zhemv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const int N, const void *alpha, const void *A,
const int lda, const void *X, const int incX,
const void *beta, void *Y, const int incY);
void cblas_zhbmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const int N, const int K, const void *alpha, const void *A,
const int lda, const void *X, const int incX,
const void *beta, void *Y, const int incY);
void cblas_zhpmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const int N, const void *alpha, const void *Ap,
const void *X, const int incX,
const void *beta, void *Y, const int incY);
void cblas_zgeru(const enum CBLAS_ORDER Order, const int M, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *A, const int lda);
void cblas_zgerc(const enum CBLAS_ORDER Order, const int M, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *A, const int lda);
void cblas_zher(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const void *X, const int incX,
void *A, const int lda);
void cblas_zhpr(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const int N, const double alpha, const void *X,
const int incX, void *Ap);
void cblas_zher2(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *A, const int lda);
void cblas_zhpr2(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const int N,
const void *alpha, const void *X, const int incX,
const void *Y, const int incY, void *Ap);
/*
* ===========================================================================
* Prototypes for level 3 BLAS
* ===========================================================================
*/
/*
* Routines with standard 4 prefixes (S, D, C, Z)
*/
void cblas_sgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
const int K, const float alpha, const float *A,
const int lda, const float *B, const int ldb,
const float beta, float *C, const int ldc);
void cblas_ssymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N,
const float alpha, const float *A, const int lda,
const float *B, const int ldb, const float beta,
float *C, const int ldc);
void cblas_ssyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const float alpha, const float *A, const int lda,
const float beta, float *C, const int ldc);
void cblas_ssyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const float alpha, const float *A, const int lda,
const float *B, const int ldb, const float beta,
float *C, const int ldc);
void cblas_strmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const float alpha, const float *A, const int lda,
float *B, const int ldb);
void cblas_strsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const float alpha, const float *A, const int lda,
float *B, const int ldb);
void cblas_dgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
const int K, const double alpha, const double *A,
const int lda, const double *B, const int ldb,
const double beta, double *C, const int ldc);
void cblas_dsymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N,
const double alpha, const double *A, const int lda,
const double *B, const int ldb, const double beta,
double *C, const int ldc);
void cblas_dsyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const double alpha, const double *A, const int lda,
const double beta, double *C, const int ldc);
void cblas_dsyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const double alpha, const double *A, const int lda,
const double *B, const int ldb, const double beta,
double *C, const int ldc);
void cblas_dtrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const double alpha, const double *A, const int lda,
double *B, const int ldb);
void cblas_dtrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const double alpha, const double *A, const int lda,
double *B, const int ldb);
void cblas_cgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
const int K, const void *alpha, const void *A,
const int lda, const void *B, const int ldb,
const void *beta, void *C, const int ldc);
void cblas_csymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const void *beta,
void *C, const int ldc);
void cblas_csyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const void *alpha, const void *A, const int lda,
const void *beta, void *C, const int ldc);
void cblas_csyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const void *beta,
void *C, const int ldc);
void cblas_ctrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const void *alpha, const void *A, const int lda,
void *B, const int ldb);
void cblas_ctrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const void *alpha, const void *A, const int lda,
void *B, const int ldb);
void cblas_zgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_TRANSPOSE TransB, const int M, const int N,
const int K, const void *alpha, const void *A,
const int lda, const void *B, const int ldb,
const void *beta, void *C, const int ldc);
void cblas_zsymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const void *beta,
void *C, const int ldc);
void cblas_zsyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const void *alpha, const void *A, const int lda,
const void *beta, void *C, const int ldc);
void cblas_zsyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const void *beta,
void *C, const int ldc);
void cblas_ztrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const void *alpha, const void *A, const int lda,
void *B, const int ldb);
void cblas_ztrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA,
const enum CBLAS_DIAG Diag, const int M, const int N,
const void *alpha, const void *A, const int lda,
void *B, const int ldb);
/*
* Routines with prefixes C and Z only
*/
void cblas_chemm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const void *beta,
void *C, const int ldc);
void cblas_cherk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const float alpha, const void *A, const int lda,
const float beta, void *C, const int ldc);
void cblas_cher2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const float beta,
void *C, const int ldc);
void cblas_zhemm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side,
const enum CBLAS_UPLO Uplo, const int M, const int N,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const void *beta,
void *C, const int ldc);
void cblas_zherk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const double alpha, const void *A, const int lda,
const double beta, void *C, const int ldc);
void cblas_zher2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
const enum CBLAS_TRANSPOSE Trans, const int N, const int K,
const void *alpha, const void *A, const int lda,
const void *B, const int ldb, const double beta,
void *C, const int ldc);
int cblas_errprn(int ierr, int info, char *form, ...);
#endif /* end #ifdef CBLAS_ENUM_ONLY */
#endif

View File

@@ -0,0 +1,47 @@
package cgo
import (
"testing"
"github.com/gonum/blas/testblas"
)
// The Dgemm benchmarks below all delegate to testblas.DgemmBenchmark using the
// package-level impl value. Each passes three size constants (Sm, Med, Lg or
// Hg) followed by two transpose flags (NT or T); the function name spells out
// the same combination, with a TNT/NTT/TT suffix when a flag is not NT.
// NOTE(review): the three sizes are presumably m, n and k and the flags tA and
// tB — confirm against testblas.DgemmBenchmark.

func BenchmarkDgemmSmSmSm(b *testing.B) {
testblas.DgemmBenchmark(b, impl, Sm, Sm, Sm, NT, NT)
}
func BenchmarkDgemmMedMedMed(b *testing.B) {
testblas.DgemmBenchmark(b, impl, Med, Med, Med, NT, NT)
}
func BenchmarkDgemmMedLgMed(b *testing.B) {
testblas.DgemmBenchmark(b, impl, Med, Lg, Med, NT, NT)
}
func BenchmarkDgemmLgLgLg(b *testing.B) {
testblas.DgemmBenchmark(b, impl, Lg, Lg, Lg, NT, NT)
}
func BenchmarkDgemmLgSmLg(b *testing.B) {
testblas.DgemmBenchmark(b, impl, Lg, Sm, Lg, NT, NT)
}
func BenchmarkDgemmLgLgSm(b *testing.B) {
testblas.DgemmBenchmark(b, impl, Lg, Lg, Sm, NT, NT)
}
func BenchmarkDgemmHgHgSm(b *testing.B) {
testblas.DgemmBenchmark(b, impl, Hg, Hg, Sm, NT, NT)
}
func BenchmarkDgemmMedMedMedTNT(b *testing.B) {
testblas.DgemmBenchmark(b, impl, Med, Med, Med, T, NT)
}
func BenchmarkDgemmMedMedMedNTT(b *testing.B) {
testblas.DgemmBenchmark(b, impl, Med, Med, Med, NT, T)
}
func BenchmarkDgemmMedMedMedTT(b *testing.B) {
testblas.DgemmBenchmark(b, impl, Med, Med, Med, T, T)
}

View File

@@ -0,0 +1,87 @@
package cgo
import (
"testing"
"github.com/gonum/blas/testblas"
)
// The Dgemv benchmarks below all delegate to testblas.DgemvBenchmark using the
// package-level impl value. Each passes a transpose flag (NT or T), two size
// constants (Sm, Med or Lg) and two integer increments. The "Inc1" variants
// pass increments 1, 1 and the "IncN" variants pass 2, 3.
// NOTE(review): the sizes are presumably m and n and the increments incX and
// incY — confirm against testblas.DgemvBenchmark.

func BenchmarkDgemvSmSmNoTransInc1(b *testing.B) {
testblas.DgemvBenchmark(b, impl, NT, Sm, Sm, 1, 1)
}
func BenchmarkDgemvSmSmNoTransIncN(b *testing.B) {
testblas.DgemvBenchmark(b, impl, NT, Sm, Sm, 2, 3)
}
func BenchmarkDgemvSmSmTransInc1(b *testing.B) {
testblas.DgemvBenchmark(b, impl, T, Sm, Sm, 1, 1)
}
func BenchmarkDgemvSmSmTransIncN(b *testing.B) {
testblas.DgemvBenchmark(b, impl, T, Sm, Sm, 2, 3)
}
func BenchmarkDgemvMedMedNoTransInc1(b *testing.B) {
testblas.DgemvBenchmark(b, impl, NT, Med, Med, 1, 1)
}
func BenchmarkDgemvMedMedNoTransIncN(b *testing.B) {
testblas.DgemvBenchmark(b, impl, NT, Med, Med, 2, 3)
}
func BenchmarkDgemvMedMedTransInc1(b *testing.B) {
testblas.DgemvBenchmark(b, impl, T, Med, Med, 1, 1)
}
func BenchmarkDgemvMedMedTransIncN(b *testing.B) {
testblas.DgemvBenchmark(b, impl, T, Med, Med, 2, 3)
}
func BenchmarkDgemvLgLgNoTransInc1(b *testing.B) {
testblas.DgemvBenchmark(b, impl, NT, Lg, Lg, 1, 1)
}
func BenchmarkDgemvLgLgNoTransIncN(b *testing.B) {
testblas.DgemvBenchmark(b, impl, NT, Lg, Lg, 2, 3)
}
func BenchmarkDgemvLgLgTransInc1(b *testing.B) {
testblas.DgemvBenchmark(b, impl, T, Lg, Lg, 1, 1)
}
func BenchmarkDgemvLgLgTransIncN(b *testing.B) {
testblas.DgemvBenchmark(b, impl, T, Lg, Lg, 2, 3)
}
func BenchmarkDgemvLgSmNoTransInc1(b *testing.B) {
testblas.DgemvBenchmark(b, impl, NT, Lg, Sm, 1, 1)
}
func BenchmarkDgemvLgSmNoTransIncN(b *testing.B) {
testblas.DgemvBenchmark(b, impl, NT, Lg, Sm, 2, 3)
}
func BenchmarkDgemvLgSmTransInc1(b *testing.B) {
testblas.DgemvBenchmark(b, impl, T, Lg, Sm, 1, 1)
}
func BenchmarkDgemvLgSmTransIncN(b *testing.B) {
testblas.DgemvBenchmark(b, impl, T, Lg, Sm, 2, 3)
}
func BenchmarkDgemvSmLgNoTransInc1(b *testing.B) {
testblas.DgemvBenchmark(b, impl, NT, Sm, Lg, 1, 1)
}
func BenchmarkDgemvSmLgNoTransIncN(b *testing.B) {
testblas.DgemvBenchmark(b, impl, NT, Sm, Lg, 2, 3)
}
func BenchmarkDgemvSmLgTransInc1(b *testing.B) {
testblas.DgemvBenchmark(b, impl, T, Sm, Lg, 1, 1)
}
func BenchmarkDgemvSmLgTransIncN(b *testing.B) {
testblas.DgemvBenchmark(b, impl, T, Sm, Lg, 2, 3)
}

View File

@@ -0,0 +1,47 @@
package cgo
import (
"testing"
"github.com/gonum/blas/testblas"
)
// The Dger benchmarks below all delegate to testblas.DgerBenchmark using the
// package-level impl value. Each passes two size constants (Sm, Med or Lg) and
// two integer increments. The "Inc1" variants pass increments 1, 1 and the
// "IncN" variants pass 2, 3.
// NOTE(review): the sizes are presumably m and n and the increments incX and
// incY — confirm against testblas.DgerBenchmark.

func BenchmarkDgerSmSmInc1(b *testing.B) {
testblas.DgerBenchmark(b, impl, Sm, Sm, 1, 1)
}
func BenchmarkDgerSmSmIncN(b *testing.B) {
testblas.DgerBenchmark(b, impl, Sm, Sm, 2, 3)
}
func BenchmarkDgerMedMedInc1(b *testing.B) {
testblas.DgerBenchmark(b, impl, Med, Med, 1, 1)
}
func BenchmarkDgerMedMedIncN(b *testing.B) {
testblas.DgerBenchmark(b, impl, Med, Med, 2, 3)
}
func BenchmarkDgerLgLgInc1(b *testing.B) {
testblas.DgerBenchmark(b, impl, Lg, Lg, 1, 1)
}
func BenchmarkDgerLgLgIncN(b *testing.B) {
testblas.DgerBenchmark(b, impl, Lg, Lg, 2, 3)
}
func BenchmarkDgerLgSmInc1(b *testing.B) {
testblas.DgerBenchmark(b, impl, Lg, Sm, 1, 1)
}
func BenchmarkDgerLgSmIncN(b *testing.B) {
testblas.DgerBenchmark(b, impl, Lg, Sm, 2, 3)
}
func BenchmarkDgerSmLgInc1(b *testing.B) {
testblas.DgerBenchmark(b, impl, Sm, Lg, 1, 1)
}
func BenchmarkDgerSmLgIncN(b *testing.B) {
testblas.DgerBenchmark(b, impl, Sm, Lg, 2, 3)
}

95
blas/cgo/doc.go Normal file
View File

@@ -0,0 +1,95 @@
// Copyright ©2015 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:generate go run generate_blas.go
// Ensure changes made to blas/cgo are reflected in blas/native where relevant.
/*
Package cgo provides bindings to a C BLAS library. This wrapper interface
panics when the input arguments are invalid as per the standard, for example
if a vector increment is zero. Please note that the treatment of NaN values
is not specified, and differs among the BLAS implementations.
github.com/gonum/blas/blas64 provides helpful wrapper functions to the BLAS
interface. The rest of this text describes the layout of the data for the input types.
Please note that in the function documentation, x[i] refers to the i^th element
of the vector, which will be different from the i^th element of the slice if
incX != 1.
Vector arguments are effectively strided slices. They have two input arguments,
a number of elements, n, and an increment, incX. The increment specifies the
distance between elements of the vector. The actual Go slice may be longer
than necessary.
The increment may be positive or negative, except in functions with only
a single vector argument where the increment may only be positive. If the increment
is negative, s[0] is the last element in the slice. Note that this is not the same
as counting backward from the end of the slice, as len(s) may be longer than
necessary. So, for example, if n = 5 and incX = 3, the elements of s are
[0 * * 1 * * 2 * * 3 * * 4 * * * ...]
where * elements are never accessed. If incX = -3, the same elements are
accessed, just in reverse order (4, 3, 2, 1, 0).
Dense matrices are specified by a number of rows, a number of columns, and a stride.
The stride specifies the number of entries in the slice between the first element
of successive rows. The stride must be at least as large as the number of columns
but may be longer.
[a00 ... a0n a0* ... a1stride-1 a21 ... amn am* ... amstride-1]
Thus, dense[i*ld + j] refers to the {i, j}th element of the matrix.
Symmetric and triangular matrices (non-packed) are stored identically to Dense,
except that only elements in one triangle of the matrix are accessed.
Packed symmetric and packed triangular matrices are laid out with the entries
condensed such that all of the unreferenced elements are removed. So, the upper triangular
matrix
[
1 2 3
0 4 5
0 0 6
]
and the lower-triangular matrix
[
1 0 0
2 3 0
4 5 6
]
will both be compacted as [1 2 3 4 5 6]. The (i, j) element of the original
dense matrix can be found at element i*n - (i-1)*i/2 + j for upper triangular,
and at element i * (i+1) /2 + j for lower triangular.
Banded matrices are laid out in a compact format, constructed by removing the
zeros in the rows and aligning the diagonals. For example, the matrix
[
1 2 3 0 0 0
4 5 6 7 0 0
0 8 9 10 11 0
0 0 12 13 14 15
0 0 0 16 17 18
0 0 0 0 19 20
]
implicitly becomes (* entries are never accessed)
[
* 1 2 3
4 5 6 7
8 9 10 11
12 13 14 15
16 17 18 *
19 20 * *
]
which is given to the BLAS routine as [* 1 2 3 4 ...].
See http://www.crest.iu.edu/research/mtl/reference/html/banded.html
for more information.
*/
package cgo
// BUG(btracey): The cgo package is intrinsically dependent on the underlying C
// implementation. The BLAS standard is silent on a number of behaviors, including
// but not limited to how NaN values are treated. For this reason the result of
// computations performed by the cgo BLAS package may disagree with the results
// produced by the native BLAS package. The cgo package is tested against OpenBLAS;
// use of other backing BLAS C libraries may result in test failure because of this.

View File

@@ -0,0 +1,54 @@
// Copyright ©2017 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build go1.7
package cgo
import (
"strconv"
"testing"
"github.com/gonum/blas"
"github.com/gonum/blas/testblas"
)
// BenchmarkDtrmv runs testblas.DtrmvBenchmark as a sub-benchmark for every
// combination of matrix size, x increment, triangle, transpose and diagonal
// setting listed in the loops below. The sub-benchmark name is assembled from
// one tag per setting, e.g. "Med_Inc1_UP_NT_NU". The leading dimension passed
// to the benchmark is always n.
func BenchmarkDtrmv(b *testing.B) {
	for _, n := range []int{testblas.MediumMat, testblas.LargeMat} {
		var sizeTag string
		if n == testblas.MediumMat {
			sizeTag = "Med"
		} else if n == testblas.LargeMat {
			sizeTag = "Large"
		}
		for _, incX := range []int{1, 5} {
			incTag := "_Inc" + strconv.Itoa(incX)
			for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} {
				uploTag := "_LO"
				if uplo == blas.Upper {
					uploTag = "_UP"
				}
				for _, trans := range []blas.Transpose{blas.NoTrans, blas.Trans} {
					transTag := "_TR"
					if trans == blas.NoTrans {
						transTag = "_NT"
					}
					for _, unit := range []blas.Diag{blas.NonUnit, blas.Unit} {
						diagTag := "_UN"
						if unit == blas.NonUnit {
							diagTag = "_NU"
						}
						name := sizeTag + incTag + uploTag + transTag + diagTag
						lda := n
						b.Run(name, func(b *testing.B) {
							testblas.DtrmvBenchmark(b, Implementation{}, n, lda, incX, uplo, trans, unit)
						})
					}
				}
			}
		}
	}
}

978
blas/cgo/generate_blas.go Normal file
View File

@@ -0,0 +1,978 @@
// Copyright ©2016 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build ignore
// generate_blas creates a blas.go file from the provided C header file
// with optionally added documentation from the documentation package.
package main
import (
"bytes"
"fmt"
"go/ast"
"go/format"
"io/ioutil"
"log"
"strings"
"text/template"
"github.com/cznic/cc"
"github.com/gonum/internal/binding"
)
// File names and identifiers that drive the generator.
const (
// header is the C header parsed for declarations; it is also the value
// substituted for {{.}} when the handwritten template is executed.
header = "cblas.h"
// documentation is the directory passed to binding.DocComments when
// cribDocs is set.
documentation = "../native"
// target is the file main writes the formatted output to.
target = "blas.go"
// typ is the receiver type used in generated method signatures and the
// key used to select that type's doc comments.
typ = "Implementation"
// prefix selects which C declarations are wrapped and is trimmed from the
// C name to form the Go method name.
prefix = "cblas_"
// warning is looked for in the last line of a cribbed doc comment; when it
// is found goSignature drops the final two comment lines.
warning = "Float32 implementations are autogenerated and not directly tested."
)
// Switches controlling the shape of the generated output.
const (
// cribDocs makes main load doc comments from the documentation directory;
// it also causes a blank line to be written between generated functions.
cribDocs = true
// elideRepeat makes goSignature omit a parameter's type when the following
// parameter renders to the same Go type.
elideRepeat = true
// noteOrigin makes main write a "declared at" comment at the top of each
// generated function body.
noteOrigin = true
// separateFuncs requests a blank line between generated functions (the
// same separation is applied whenever cribDocs is set).
separateFuncs = false
)
// skip lists the C declarations that main does not generate wrappers for.
// The rot*, rotm* and dot*_sub routines have manually written wrappers in the
// handwritten template; no wrapper for the remaining entries is visible in
// this file.
var skip = map[string]bool{
"cblas_errprn": true,
"cblas_srotg": true,
"cblas_srotmg": true,
"cblas_srotm": true,
"cblas_drotg": true,
"cblas_drotmg": true,
"cblas_drotm": true,
"cblas_crotg": true,
"cblas_zrotg": true,
"cblas_cdotu_sub": true,
"cblas_cdotc_sub": true,
"cblas_zdotu_sub": true,
"cblas_zdotc_sub": true,
// ATLAS extensions.
"cblas_csrot": true,
"cblas_zdrot": true,
}
// cToGoType maps the string form of a C return type to the Go type used for
// the result in the generated signature and in the conversion wrapped around
// the cgo call.
var cToGoType = map[string]string{
"int": "int",
"float": "float32",
"double": "float64",
}
// blasEnums maps each CBLAS enum name to a template yielding the Go type used
// for it in generated signatures. It is passed to binding.GoTypeForEnum by
// goSignature, which drops any parameter whose type renders as "order" from
// the Go signature.
var blasEnums = map[string]*template.Template{
"CBLAS_ORDER": template.Must(template.New("order").Parse("order")),
"CBLAS_DIAG": template.Must(template.New("diag").Parse("blas.Diag")),
"CBLAS_TRANSPOSE": template.Must(template.New("trans").Parse("blas.Transpose")),
"CBLAS_UPLO": template.Must(template.New("uplo").Parse("blas.Uplo")),
"CBLAS_SIDE": template.Must(template.New("side").Parse("blas.Side")),
}
// cgoEnums maps each CBLAS enum name to a template yielding the expression
// that converts a Go argument to the C enum in the generated cgo call. It is
// passed to binding.CgoConversionForEnum by cgoCall. The CBLAS_ORDER template
// ignores its input and always yields the rowMajor constant.
// NOTE(review): {{.}} is presumably the Go parameter name — confirm against
// binding.CgoConversionForEnum.
var cgoEnums = map[string]*template.Template{
"CBLAS_ORDER": template.Must(template.New("order").Parse("C.enum_CBLAS_ORDER(rowMajor)")),
"CBLAS_DIAG": template.Must(template.New("diag").Parse("C.enum_CBLAS_DIAG({{.}})")),
"CBLAS_TRANSPOSE": template.Must(template.New("trans").Parse("C.enum_CBLAS_TRANSPOSE({{.}})")),
"CBLAS_UPLO": template.Must(template.New("uplo").Parse("C.enum_CBLAS_UPLO({{.}})")),
"CBLAS_SIDE": template.Must(template.New("side").Parse("C.enum_CBLAS_SIDE({{.}})")),
}
// cgoTypes overrides the cgo argument conversion for void pointer parameters.
// It is passed to binding.CgoConversionFor by cgoCall. The template yields
// unsafe.Pointer(&v) when the template input is "alpha" or "beta" and
// unsafe.Pointer(&v[0]) for any other input.
// NOTE(review): the template input is presumably the Go parameter name —
// confirm against binding.CgoConversionFor.
var cgoTypes = map[binding.TypeKey]*template.Template{
{Kind: cc.Void, IsPointer: true}: template.Must(template.New("void*").Parse(
`unsafe.Pointer(&{{.}}{{if eq . "alpha" "beta"}}{{else}}[0]{{end}})`,
)),
}
// complex64Type and complex128Type give the Go type used for void pointer
// parameters in generated signatures. goSignature selects one of them from
// the leading characters of the routine name. In both, the template yields
// the scalar complex type when its input is "alpha" or "beta" and a slice of
// that type for any other input.
var (
complex64Type = map[binding.TypeKey]*template.Template{
{Kind: cc.Void, IsPointer: true}: template.Must(template.New("void*").Parse(
`{{if eq . "alpha" "beta"}}complex64{{else}}[]complex64{{end}}`,
))}
complex128Type = map[binding.TypeKey]*template.Template{
{Kind: cc.Void, IsPointer: true}: template.Must(template.New("void*").Parse(
`{{if eq . "alpha" "beta"}}complex128{{else}}[]complex128{{end}}`,
))}
)
// names maps lower-cased C parameter names to the shorter identifiers used
// for them in the generated Go code.
var names = map[string]string{
	"uplo":   "ul",
	"trans":  "t",
	"transA": "tA",
	"transB": "tB",
	"side":   "s",
	"diag":   "d",
}

// shorten returns the abbreviated form of n listed in names, or n itself
// when no abbreviation is registered for it.
func shorten(n string) string {
	if short, found := names[n]; found {
		return short
	}
	return n
}
// main generates target from the declarations found in header. It first
// writes the handwritten preamble, then one wrapper method for every
// declaration that carries prefix and is not listed in skip, and finally
// formats the accumulated source and writes it to target. Any failure is
// fatal.
func main() {
	decls, err := binding.Declarations(header)
	if err != nil {
		log.Fatal(err)
	}

	// Doc comments are optional; docs stays nil when cribDocs is off.
	var docs map[string]map[string][]*ast.Comment
	if cribDocs {
		if docs, err = binding.DocComments(documentation); err != nil {
			log.Fatal(err)
		}
	}

	preamble, err := template.New("handwritten").Parse(handwritten)
	if err != nil {
		log.Fatal(err)
	}
	var out bytes.Buffer
	if err := preamble.Execute(&out, header); err != nil {
		log.Fatal(err)
	}

	blankBetween := separateFuncs || cribDocs
	emitted := 0
	for _, d := range decls {
		if skip[d.Name] || !strings.HasPrefix(d.Name, prefix) {
			continue
		}
		if emitted > 0 && blankBetween {
			out.WriteByte('\n')
		}
		emitted++

		goSignature(&out, d, docs[typ])
		if noteOrigin {
			fmt.Fprintf(&out, "\t// declared at %s %s %s ...\n\n", d.Position(), d.Return, d.Name)
		}
		parameterChecks(&out, d, parameterCheckRules)
		out.WriteByte('\t')
		cgoCall(&out, d)
		out.WriteString("}\n")
	}

	src, err := format.Source(out.Bytes())
	if err != nil {
		log.Fatal(err)
	}
	if err := ioutil.WriteFile(target, src, 0664); err != nil {
		log.Fatal(err)
	}
}
// goSignature writes the opening of the Go method generated for d to buf:
// the doc comment found in docs for the method (if any) followed by the
// method signature up to and including the opening brace.
//
// When the last line of the doc comment contains warning, the final two
// comment lines are dropped. Parameters whose enum type renders as "order"
// are left out of the signature. Void pointer parameters are rendered with
// complex64Type or complex128Type depending on the leading characters of
// the routine name. With elideRepeat set, a parameter's type is omitted when
// the next parameter has the same C kind and renders to the same Go type.
func goSignature(buf *bytes.Buffer, d binding.Declaration, docs map[string][]*ast.Comment) {
	blasName := strings.TrimPrefix(d.Name, prefix)
	goName := binding.UpperCaseFirst(blasName)

	// Indexing a nil map is safe, so a nil docs simply yields no comment.
	if doc, ok := docs[goName]; ok {
		if last := doc[len(doc)-1]; strings.Contains(last.Text, warning) {
			doc = doc[:len(doc)-2]
		}
		for _, c := range doc {
			buf.WriteString(c.Text)
			buf.WriteByte('\n')
		}
	}

	parameters := d.Parameters()

	// The first void pointer parameter decides the complex type used for all
	// void pointers of this routine.
	var voidPtrType map[binding.TypeKey]*template.Template
	for _, p := range parameters {
		if p.Kind() != cc.Ptr || p.Elem().Kind() != cc.Void {
			continue
		}
		if blasName[0] == 'c' || (blasName[1] == 'c' && blasName[0] != 'z') {
			voidPtrType = complex64Type
		} else if blasName[0] == 'z' || blasName[1] == 'z' {
			voidPtrType = complex128Type
		}
		break
	}

	// render returns the Go name and Go type text for a parameter.
	render := func(p binding.Parameter) (string, string) {
		name := shorten(binding.LowerCaseFirst(p.Name()))
		if p.Kind() == cc.Enum {
			return name, binding.GoTypeForEnum(p.Type(), name, blasEnums)
		}
		return name, binding.GoTypeFor(p.Type(), name, voidPtrType)
	}

	fmt.Fprintf(buf, "func (%s) %s(", typ, goName)
	written := 0
	for i, p := range parameters {
		if p.Kind() == cc.Enum && binding.GoTypeForEnum(p.Type(), "", blasEnums) == "order" {
			continue
		}
		if written > 0 {
			buf.WriteString(", ")
		}
		written++

		n, this := render(p)
		var next string
		if elideRepeat && i+1 < len(parameters) && p.Type().Kind() == parameters[i+1].Type().Kind() {
			_, next = render(parameters[i+1])
		}
		if this == next {
			buf.WriteString(n)
		} else {
			fmt.Fprintf(buf, "%s %s", n, this)
		}
	}

	if d.Return.Kind() == cc.Void {
		buf.WriteString(") {\n")
		return
	}
	fmt.Fprintf(buf, ") %s {\n", cToGoType[d.Return.String()])
}
// parameterChecks applies rules to every parameter of d, writing whatever
// argument checks they produce to buf. For each parameter the rules are tried
// in slice order; a rule that has returned true is considered finished for
// this declaration and is not called again.
func parameterChecks(buf *bytes.Buffer, d binding.Declaration, rules []func(*bytes.Buffer, binding.Declaration, binding.Parameter) bool) {
	finished := make([]bool, len(rules))
	for _, p := range d.Parameters() {
		for i, rule := range rules {
			if !finished[i] {
				finished[i] = rule(buf, d, p)
			}
		}
	}
}
// cgoCall writes the statement that invokes the C function described by d to
// buf, converting every argument with cgoEnums or cgoTypes. When d returns a
// value the call is wrapped in a conversion to the matching Go type and
// returned.
func cgoCall(buf *bytes.Buffer, d binding.Declaration) {
	hasResult := d.Return.Kind() != cc.Void
	if hasResult {
		fmt.Fprintf(buf, "return %s(", cToGoType[d.Return.String()])
	}
	fmt.Fprintf(buf, "C.%s(", d.Name)

	sep := ""
	for _, p := range d.Parameters() {
		buf.WriteString(sep)
		sep = ", "

		name := shorten(binding.LowerCaseFirst(p.Name()))
		if p.Type().Kind() == cc.Enum {
			buf.WriteString(binding.CgoConversionForEnum(name, p.Type(), cgoEnums))
			continue
		}
		buf.WriteString(binding.CgoConversionFor(name, p.Type(), cgoTypes))
	}

	// Close the C call, and the Go conversion around it when there is one.
	if hasResult {
		buf.WriteString("))\n")
	} else {
		buf.WriteString(")\n")
	}
}
// parameterCheckRules is the list of check generators main hands to
// parameterChecks. For each parameter the rules run in the order given here,
// so this order determines the order of the checks in the generated code.
// A rule returns true once it is finished for a declaration; several rules
// return false until they see the last parameter so that their output is
// written after the per-parameter checks.
var parameterCheckRules = []func(*bytes.Buffer, binding.Declaration, binding.Parameter) bool{
trans,
uplo,
diag,
side,
shape,
apShape,
zeroInc,
sidedShape,
mvShape,
rkShape,
gemmShape,
scalShape,
amaxShape,
nrmSumShape,
vectorShape,
othersShape,
noWork,
}
// amaxShape writes the early return and the x bounds check for the i?amax
// routines. For every other declaration the rule is finished immediately.
// The checks are written only when p is the last C parameter of d; before
// that false is returned so the rule is tried again.
func amaxShape(buf *bytes.Buffer, d binding.Declaration, p binding.Parameter) bool {
	if n := d.Name; n != "cblas_isamax" && n != "cblas_idamax" && n != "cblas_icamax" && n != "cblas_izamax" {
		return true
	}
	if last := d.CParameters[len(d.CParameters)-1]; last != p.Parameter {
		return false // Not the final parameter yet.
	}
	const check = ` if n == 0 || incX < 0 {
return -1
}
if incX > 0 && (n-1)*incX >= len(x) {
panic("blas: x index out of range")
}
`
	fmt.Fprint(buf, check)
	return true
}
// apShape writes the length check for a packed matrix argument when p is the
// parameter named "ap" and reports the rule as finished; for any other
// parameter it writes nothing and returns false.
func apShape(buf *bytes.Buffer, _ binding.Declaration, p binding.Parameter) bool {
	const check = ` if n*(n+1)/2 > len(ap) {
panic("blas: index of ap out of range")
}
`
	if binding.LowerCaseFirst(p.Name()) == "ap" {
		fmt.Fprint(buf, check)
		return true
	}
	return false
}
// diag writes the validity check for the diagonal argument when p is the
// parameter named "Diag" and reports the rule as finished; for any other
// parameter it writes nothing and returns false.
func diag(buf *bytes.Buffer, _ binding.Declaration, p binding.Parameter) bool {
	const check = ` if d != blas.NonUnit && d != blas.Unit {
panic("blas: illegal diagonal")
}
`
	if p.Name() == "Diag" {
		fmt.Fprint(buf, check)
		return true
	}
	return false
}
// gemmShape writes the bounds checks for the a, b and c matrices of the
// ?gemm routines, taking the transpose arguments into account. For every
// other declaration the rule is finished immediately. The checks are written
// only when p is the last C parameter of d; before that false is returned so
// the rule is tried again.
func gemmShape(buf *bytes.Buffer, d binding.Declaration, p binding.Parameter) bool {
	if n := d.Name; n != "cblas_sgemm" && n != "cblas_dgemm" && n != "cblas_cgemm" && n != "cblas_zgemm" {
		return true
	}
	if last := d.CParameters[len(d.CParameters)-1]; last != p.Parameter {
		return false // Not the final parameter yet.
	}
	const check = ` var rowA, colA, rowB, colB int
if tA == blas.NoTrans {
rowA, colA = m, k
} else {
rowA, colA = k, m
}
if tB == blas.NoTrans {
rowB, colB = k, n
} else {
rowB, colB = n, k
}
if lda*(rowA-1)+colA > len(a) || lda < max(1, colA) {
panic("blas: index of a out of range")
}
if ldb*(rowB-1)+colB > len(b) || ldb < max(1, colB) {
panic("blas: index of b out of range")
}
if ldc*(m-1)+n > len(c) || ldc < max(1, n) {
panic("blas: index of c out of range")
}
`
	fmt.Fprint(buf, check)
	return true
}
// mvShape writes the x and y bounds checks for the ?gbmv and ?gemv routines,
// where the vector lengths depend on the transpose argument. For every other
// declaration the rule is finished immediately. The checks are written only
// when p is the last C parameter of d; before that false is returned so the
// rule is tried again.
func mvShape(buf *bytes.Buffer, d binding.Declaration, p binding.Parameter) bool {
	isMV := false
	switch d.Name {
	case "cblas_sgbmv", "cblas_dgbmv", "cblas_cgbmv", "cblas_zgbmv",
		"cblas_sgemv", "cblas_dgemv", "cblas_cgemv", "cblas_zgemv":
		isMV = true
	}
	if !isMV {
		return true
	}
	if last := d.CParameters[len(d.CParameters)-1]; last != p.Parameter {
		return false // Not the final parameter yet.
	}
	const check = ` var lenX, lenY int
if tA == blas.NoTrans {
lenX, lenY = n, m
} else {
lenX, lenY = m, n
}
if (incX > 0 && (lenX-1)*incX >= len(x)) || (incX < 0 && (1-lenX)*incX >= len(x)) {
panic("blas: x index out of range")
}
if (incY > 0 && (lenY-1)*incY >= len(y)) || (incY < 0 && (1-lenY)*incY >= len(y)) {
panic("blas: y index out of range")
}
`
	fmt.Fprint(buf, check)
	return true
}
// noWork writes an early return for n == 0 in routines that have an n
// parameter but neither lda nor ldb. The value returned by the generated
// code is -1 for an int result, 0 for a float or double result and nothing
// otherwise. The rule is finished immediately for declarations it does not
// apply to; otherwise it writes only when p is the last C parameter of d.
func noWork(buf *bytes.Buffer, d binding.Declaration, p binding.Parameter) bool {
	seen := make(map[string]bool)
	for _, q := range d.Parameters() {
		seen[shorten(binding.LowerCaseFirst(q.Name()))] = true
	}
	if !seen["n"] || seen["lda"] || seen["ldb"] {
		return true
	}
	if last := d.CParameters[len(d.CParameters)-1]; last != p.Parameter {
		return false // Not the final parameter yet.
	}

	var result string
	if r := d.Return.String(); r == "int" {
		result = " -1"
	} else if r == "float" || r == "double" {
		result = " 0"
	}
	fmt.Fprintf(buf, ` if n == 0 {
return%s
}
`, result)
	return true
}
// nrmSumShape writes the negative-increment early return and the x bounds
// check for the ?nrm2 and ?asum routines. For every other declaration the
// rule is finished immediately. The checks are written only when p is the
// last C parameter of d; before that false is returned so the rule is tried
// again.
func nrmSumShape(buf *bytes.Buffer, d binding.Declaration, p binding.Parameter) bool {
	applies := false
	switch d.Name {
	case "cblas_snrm2", "cblas_dnrm2", "cblas_scnrm2", "cblas_dznrm2",
		"cblas_sasum", "cblas_dasum", "cblas_scasum", "cblas_dzasum":
		applies = true
	}
	if !applies {
		return true
	}
	if last := d.CParameters[len(d.CParameters)-1]; last != p.Parameter {
		return false // Not the final parameter yet.
	}
	const check = ` if incX < 0 {
return 0
}
if incX > 0 && (n-1)*incX >= len(x) {
panic("blas: x index out of range")
}
`
	fmt.Fprint(buf, check)
	return true
}
// rkShape writes the matrix bounds checks for the ?syrk, ?syr2k, ?herk and
// ?her2k routines. The generated code derives the row and column counts of
// the a and b matrices from the transpose argument, then checks each pointer
// parameter named a, b or c that d actually has. For every other declaration
// the rule is finished immediately. Output is written only when p is the
// last C parameter of d; before that false is returned so the rule is tried
// again.
func rkShape(buf *bytes.Buffer, d binding.Declaration, p binding.Parameter) bool {
	applies := false
	switch d.Name {
	case "cblas_ssyrk", "cblas_dsyrk", "cblas_csyrk", "cblas_zsyrk",
		"cblas_ssyr2k", "cblas_dsyr2k", "cblas_csyr2k", "cblas_zsyr2k",
		"cblas_cherk", "cblas_zherk", "cblas_cher2k", "cblas_zher2k":
		applies = true
	}
	if !applies {
		return true
	}
	if last := d.CParameters[len(d.CParameters)-1]; last != p.Parameter {
		return false // Not the final parameter yet.
	}

	fmt.Fprint(buf, ` var row, col int
if t == blas.NoTrans {
row, col = n, k
} else {
row, col = k, n
}
`)

	// Collect the names of the pointer parameters.
	ptrs := make(map[string]bool)
	for _, q := range d.Parameters() {
		if q.Kind() == cc.Ptr {
			ptrs[shorten(binding.LowerCaseFirst(q.Name()))] = true
		}
	}

	for _, label := range []string{"a", "b"} {
		if !ptrs[label] {
			continue
		}
		fmt.Fprintf(buf, ` if ld%[1]s*(row-1)+col > len(%[1]s) || ld%[1]s < max(1, col) {
panic("blas: index of %[1]s out of range")
}
`, label)
	}
	if ptrs["c"] {
		fmt.Fprint(buf, ` if ldc*(n-1)+n > len(c) || ldc < max(1, n) {
panic("blas: index of c out of range")
}
`)
	}
	return true
}
// scalShape writes the negative-increment early return and the x bounds
// check for the ?scal routines. For every other declaration the rule is
// finished immediately. The checks are written only when p is the last C
// parameter of d; before that false is returned so the rule is tried again.
func scalShape(buf *bytes.Buffer, d binding.Declaration, p binding.Parameter) bool {
	applies := false
	switch d.Name {
	case "cblas_sscal", "cblas_dscal", "cblas_cscal", "cblas_zscal", "cblas_csscal":
		applies = true
	}
	if !applies {
		return true
	}
	if last := d.CParameters[len(d.CParameters)-1]; last != p.Parameter {
		return false // Not the final parameter yet.
	}
	const check = ` if incX < 0 {
return
}
if incX > 0 && (n-1)*incX >= len(x) {
panic("blas: x index out of range")
}
`
	fmt.Fprint(buf, check)
	return true
}
// shape writes a non-negativity check for p when it is one of the dimension
// parameters m, n, k, kL or kU. It always returns false, so the rule is
// applied to every parameter of a declaration.
func shape(buf *bytes.Buffer, _ binding.Declaration, p binding.Parameter) bool {
	name := binding.LowerCaseFirst(p.Name())
	for _, dim := range []string{"m", "n", "k", "kL", "kU"} {
		if name != dim {
			continue
		}
		fmt.Fprintf(buf, ` if %[1]s < 0 {
panic("blas: %[1]s < 0")
}
`, name)
		break
	}
	return false
}
// side writes the validity check for the side argument when p is the
// parameter named "Side" and reports the rule as finished; for any other
// parameter it writes nothing and returns false.
func side(buf *bytes.Buffer, _ binding.Declaration, p binding.Parameter) bool {
	const check = ` if s != blas.Left && s != blas.Right {
panic("blas: illegal side")
}
`
	if p.Name() == "Side" {
		fmt.Fprint(buf, check)
		return true
	}
	return false
}
// sidedShape writes the matrix bounds checks for routines that take a side
// argument (shortened to "s") together with both an a and a b parameter; a
// check for c is added when d has a c parameter as well. The rule is
// finished immediately for declarations without an "s" parameter. Otherwise
// it waits for the last C parameter of d, and then finishes without output
// if a or b is missing.
func sidedShape(buf *bytes.Buffer, d binding.Declaration, p binding.Parameter) bool {
	seen := make(map[string]bool)
	for _, q := range d.Parameters() {
		seen[shorten(binding.LowerCaseFirst(q.Name()))] = true
	}
	if !seen["s"] {
		return true
	}
	if last := d.CParameters[len(d.CParameters)-1]; last != p.Parameter {
		return false // Not the final parameter yet.
	}
	if !(seen["a"] && seen["b"]) {
		return true
	}

	fmt.Fprint(buf, ` var k int
if s == blas.Left {
k = m
} else {
k = n
}
if lda*(k-1)+k > len(a) || lda < max(1, k) {
panic("blas: index of a out of range")
}
if ldb*(m-1)+n > len(b) || ldb < max(1, n) {
panic("blas: index of b out of range")
}
`)
	if seen["c"] {
		fmt.Fprint(buf, ` if ldc*(m-1)+n > len(c) || ldc < max(1, n) {
panic("blas: index of c out of range")
}
`)
	}
	return true
}
// trans writes the validity check for a transpose argument when p shortens
// to t, tA or tB. Routines whose C name starts with cblas_ch or cblas_zh
// accept only NoTrans and ConjTrans, those starting with cblas_cs or
// cblas_zs accept only NoTrans and Trans, and all others accept all three
// values. It always returns false, so every transpose parameter of a
// declaration gets its own check.
func trans(buf *bytes.Buffer, d binding.Declaration, p binding.Parameter) bool {
	n := shorten(binding.LowerCaseFirst(p.Name()))
	if n != "t" && n != "tA" && n != "tB" {
		return false
	}

	var format string
	switch {
	case strings.HasPrefix(d.Name, "cblas_ch"), strings.HasPrefix(d.Name, "cblas_zh"):
		format = ` if %[1]s != blas.NoTrans && %[1]s != blas.ConjTrans {
panic("blas: illegal transpose")
}
`
	case strings.HasPrefix(d.Name, "cblas_cs"), strings.HasPrefix(d.Name, "cblas_zs"):
		format = ` if %[1]s != blas.NoTrans && %[1]s != blas.Trans {
panic("blas: illegal transpose")
}
`
	default:
		format = ` if %[1]s != blas.NoTrans && %[1]s != blas.Trans && %[1]s != blas.ConjTrans {
panic("blas: illegal transpose")
}
`
	}
	fmt.Fprintf(buf, format, n)
	return false
}
// uplo writes the validity check for the triangle argument when p is the
// parameter named "Uplo" and reports the rule as finished; for any other
// parameter it writes nothing and returns false.
func uplo(buf *bytes.Buffer, _ binding.Declaration, p binding.Parameter) bool {
	const check = ` if ul != blas.Upper && ul != blas.Lower {
panic("blas: illegal triangle")
}
`
	if p.Name() == "Uplo" {
		fmt.Fprint(buf, check)
		return true
	}
	return false
}
// vectorShape writes the generic x and y bounds checks for routines that are
// not handled by mvShape, scalShape, amaxShape or nrmSumShape. The x check
// uses m as the vector length when d has an m parameter and n otherwise; the
// y check always uses n. The rule is finished immediately for the excluded
// routines and for declarations with neither n nor m. Otherwise output is
// written only when p is the last C parameter of d.
func vectorShape(buf *bytes.Buffer, d binding.Declaration, p binding.Parameter) bool {
	switch d.Name {
	case "cblas_sgbmv", "cblas_dgbmv", "cblas_cgbmv", "cblas_zgbmv",
		"cblas_sgemv", "cblas_dgemv", "cblas_cgemv", "cblas_zgemv",
		"cblas_sscal", "cblas_dscal", "cblas_cscal", "cblas_zscal", "cblas_csscal",
		"cblas_isamax", "cblas_idamax", "cblas_icamax", "cblas_izamax",
		"cblas_snrm2", "cblas_dnrm2", "cblas_scnrm2", "cblas_dznrm2",
		"cblas_sasum", "cblas_dasum", "cblas_scasum", "cblas_dzasum":
		return true
	}

	seen := make(map[string]bool)
	for _, q := range d.Parameters() {
		seen[shorten(binding.LowerCaseFirst(q.Name()))] = true
	}
	if !(seen["n"] || seen["m"]) {
		return true
	}
	if last := d.CParameters[len(d.CParameters)-1]; last != p.Parameter {
		return false // Not the final parameter yet.
	}

	label := "n"
	if seen["m"] {
		label = "m"
	}
	if seen["incX"] {
		fmt.Fprintf(buf, ` if (incX > 0 && (%[1]s-1)*incX >= len(x)) || (incX < 0 && (1-%[1]s)*incX >= len(x)) {
panic("blas: x index out of range")
}
`, label)
	}
	if seen["incY"] {
		fmt.Fprint(buf, ` if (incY > 0 && (n-1)*incY >= len(y)) || (incY < 0 && (1-n)*incY >= len(y)) {
panic("blas: y index out of range")
}
`)
	}
	return true
}
// zeroInc writes the zero-increment check for p when it is named incX or
// incY. The rule stays active after incX and is reported as finished only
// once incY has been handled.
func zeroInc(buf *bytes.Buffer, _ binding.Declaration, p binding.Parameter) bool {
	name := binding.LowerCaseFirst(p.Name())
	if name == "incX" {
		fmt.Fprint(buf, ` if incX == 0 {
panic("blas: zero x index increment")
}
`)
		return false
	}
	if name == "incY" {
		fmt.Fprint(buf, ` if incY == 0 {
panic("blas: zero y index increment")
}
`)
		return true
	}
	return false
}
// othersShape writes the bounds check for the a matrix in routines not
// covered by gemmShape, rkShape or sidedShape. The form of the check depends
// on which dimension parameters d has: kL and kU, m, k, or none of these.
// The rule is finished immediately for the ?gemm, ?syrk, ?syr2k, ?herk and
// ?her2k routines and for declarations that have no a parameter or that have
// an "s" parameter. Otherwise output is written only when p is the last C
// parameter of d.
func othersShape(buf *bytes.Buffer, d binding.Declaration, p binding.Parameter) bool {
	switch d.Name {
	case "cblas_sgemm", "cblas_dgemm", "cblas_cgemm", "cblas_zgemm",
		"cblas_ssyrk", "cblas_dsyrk", "cblas_csyrk", "cblas_zsyrk",
		"cblas_ssyr2k", "cblas_dsyr2k", "cblas_csyr2k", "cblas_zsyr2k",
		"cblas_cherk", "cblas_zherk", "cblas_cher2k", "cblas_zher2k":
		return true
	}

	seen := make(map[string]bool)
	for _, q := range d.Parameters() {
		seen[shorten(binding.LowerCaseFirst(q.Name()))] = true
	}
	if seen["s"] || !seen["a"] {
		return true
	}
	if last := d.CParameters[len(d.CParameters)-1]; last != p.Parameter {
		return false // Not the final parameter yet.
	}

	var check string
	switch {
	case seen["kL"] && seen["kU"]:
		check = ` if lda*(m-1)+kL+kU+1 > len(a) || lda < kL+kU+1 {
panic("blas: index of a out of range")
}
`
	case seen["m"]:
		check = ` if lda*(m-1)+n > len(a) || lda < max(1, n) {
panic("blas: index of a out of range")
}
`
	case seen["k"]:
		check = ` if lda*(n-1)+k+1 > len(a) || lda < k+1 {
panic("blas: index of a out of range")
}
`
	default:
		check = ` if lda*(n-1)+n > len(a) || lda < max(1, n) {
panic("blas: index of a out of range")
}
`
	}
	fmt.Fprint(buf, check)
	return true
}
// handwritten is the text/template source for the fixed preamble of the
// generated file. main executes it with header as its data, so every {{.}}
// below expands to the header file name. It contains the file banner, the
// cgo preamble, the interface assertions, the order type, the max helper
// used by the generated checks, the Implementation type, and the manually
// written wrappers for the rotation and complex dot-product routines that
// are listed in skip.
const handwritten = `// Do not manually edit this file. It was created by the generate_blas.go from {{.}}.
// Copyright ©2014 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package cgo
/*
#cgo CFLAGS: -g -O2
#include "{{.}}"
*/
import "C"
import (
"unsafe"
"github.com/gonum/blas"
)
// Type check assertions:
var (
_ blas.Float32 = Implementation{}
_ blas.Float64 = Implementation{}
_ blas.Complex64 = Implementation{}
_ blas.Complex128 = Implementation{}
)
// Type order is used to specify the matrix storage format. We still interact with
// an API that allows client calls to specify order, so this is here to document that fact.
type order int
const (
rowMajor order = 101 + iota
)
func max(a, b int) int {
if a > b {
return a
}
return b
}
type Implementation struct{}
// Special cases...
type srotmParams struct {
flag float32
h [4]float32
}
type drotmParams struct {
flag float64
h [4]float64
}
func (Implementation) Srotg(a float32, b float32) (c float32, s float32, r float32, z float32) {
C.cblas_srotg((*C.float)(&a), (*C.float)(&b), (*C.float)(&c), (*C.float)(&s))
return c, s, a, b
}
func (Implementation) Srotmg(d1 float32, d2 float32, b1 float32, b2 float32) (p blas.SrotmParams, rd1 float32, rd2 float32, rb1 float32) {
var pi srotmParams
C.cblas_srotmg((*C.float)(&d1), (*C.float)(&d2), (*C.float)(&b1), C.float(b2), (*C.float)(unsafe.Pointer(&pi)))
return blas.SrotmParams{Flag: blas.Flag(pi.flag), H: pi.h}, d1, d2, b1
}
func (Implementation) Srotm(n int, x []float32, incX int, y []float32, incY int, p blas.SrotmParams) {
if n < 0 {
panic("blas: n < 0")
}
if incX == 0 {
panic("blas: zero x index increment")
}
if incY == 0 {
panic("blas: zero y index increment")
}
if (incX > 0 && (n-1)*incX >= len(x)) || (incX < 0 && (1-n)*incX >= len(x)) {
panic("blas: x index out of range")
}
if (incY > 0 && (n-1)*incY >= len(y)) || (incY < 0 && (1-n)*incY >= len(y)) {
panic("blas: y index out of range")
}
if p.Flag < blas.Identity || p.Flag > blas.Diagonal {
panic("blas: illegal blas.Flag value")
}
if n == 0 {
return
}
pi := srotmParams{
flag: float32(p.Flag),
h: p.H,
}
C.cblas_srotm(C.int(n), (*C.float)(&x[0]), C.int(incX), (*C.float)(&y[0]), C.int(incY), (*C.float)(unsafe.Pointer(&pi)))
}
func (Implementation) Drotg(a float64, b float64) (c float64, s float64, r float64, z float64) {
C.cblas_drotg((*C.double)(&a), (*C.double)(&b), (*C.double)(&c), (*C.double)(&s))
return c, s, a, b
}
func (Implementation) Drotmg(d1 float64, d2 float64, b1 float64, b2 float64) (p blas.DrotmParams, rd1 float64, rd2 float64, rb1 float64) {
var pi drotmParams
C.cblas_drotmg((*C.double)(&d1), (*C.double)(&d2), (*C.double)(&b1), C.double(b2), (*C.double)(unsafe.Pointer(&pi)))
return blas.DrotmParams{Flag: blas.Flag(pi.flag), H: pi.h}, d1, d2, b1
}
func (Implementation) Drotm(n int, x []float64, incX int, y []float64, incY int, p blas.DrotmParams) {
if n < 0 {
panic("blas: n < 0")
}
if incX == 0 {
panic("blas: zero x index increment")
}
if incY == 0 {
panic("blas: zero y index increment")
}
if (incX > 0 && (n-1)*incX >= len(x)) || (incX < 0 && (1-n)*incX >= len(x)) {
panic("blas: x index out of range")
}
if (incY > 0 && (n-1)*incY >= len(y)) || (incY < 0 && (1-n)*incY >= len(y)) {
panic("blas: y index out of range")
}
if p.Flag < blas.Identity || p.Flag > blas.Diagonal {
panic("blas: illegal blas.Flag value")
}
if n == 0 {
return
}
pi := drotmParams{
flag: float64(p.Flag),
h: p.H,
}
C.cblas_drotm(C.int(n), (*C.double)(&x[0]), C.int(incX), (*C.double)(&y[0]), C.int(incY), (*C.double)(unsafe.Pointer(&pi)))
}
func (Implementation) Cdotu(n int, x []complex64, incX int, y []complex64, incY int) (dotu complex64) {
if n < 0 {
panic("blas: n < 0")
}
if incX == 0 {
panic("blas: zero x index increment")
}
if incY == 0 {
panic("blas: zero y index increment")
}
if (incX > 0 && (n-1)*incX >= len(x)) || (incX < 0 && (1-n)*incX >= len(x)) {
panic("blas: x index out of range")
}
if (incY > 0 && (n-1)*incY >= len(y)) || (incY < 0 && (1-n)*incY >= len(y)) {
panic("blas: y index out of range")
}
if n == 0 {
return 0
}
C.cblas_cdotu_sub(C.int(n), unsafe.Pointer(&x[0]), C.int(incX), unsafe.Pointer(&y[0]), C.int(incY), unsafe.Pointer(&dotu))
return dotu
}
func (Implementation) Cdotc(n int, x []complex64, incX int, y []complex64, incY int) (dotc complex64) {
if n < 0 {
panic("blas: n < 0")
}
if incX == 0 {
panic("blas: zero x index increment")
}
if incY == 0 {
panic("blas: zero y index increment")
}
if (incX > 0 && (n-1)*incX >= len(x)) || (incX < 0 && (1-n)*incX >= len(x)) {
panic("blas: x index out of range")
}
if (incY > 0 && (n-1)*incY >= len(y)) || (incY < 0 && (1-n)*incY >= len(y)) {
panic("blas: y index out of range")
}
if n == 0 {
return 0
}
C.cblas_cdotc_sub(C.int(n), unsafe.Pointer(&x[0]), C.int(incX), unsafe.Pointer(&y[0]), C.int(incY), unsafe.Pointer(&dotc))
return dotc
}
func (Implementation) Zdotu(n int, x []complex128, incX int, y []complex128, incY int) (dotu complex128) {
if n < 0 {
panic("blas: n < 0")
}
if incX == 0 {
panic("blas: zero x index increment")
}
if incY == 0 {
panic("blas: zero y index increment")
}
if (incX > 0 && (n-1)*incX >= len(x)) || (incX < 0 && (1-n)*incX >= len(x)) {
panic("blas: x index out of range")
}
if (incY > 0 && (n-1)*incY >= len(y)) || (incY < 0 && (1-n)*incY >= len(y)) {
panic("blas: y index out of range")
}
if n == 0 {
return 0
}
C.cblas_zdotu_sub(C.int(n), unsafe.Pointer(&x[0]), C.int(incX), unsafe.Pointer(&y[0]), C.int(incY), unsafe.Pointer(&dotu))
return dotu
}
func (Implementation) Zdotc(n int, x []complex128, incX int, y []complex128, incY int) (dotc complex128) {
if n < 0 {
panic("blas: n < 0")
}
if incX == 0 {
panic("blas: zero x index increment")
}
if incY == 0 {
panic("blas: zero y index increment")
}
if (incX > 0 && (n-1)*incX >= len(x)) || (incX < 0 && (1-n)*incX >= len(x)) {
panic("blas: x index out of range")
}
if (incY > 0 && (n-1)*incY >= len(y)) || (incY < 0 && (1-n)*incY >= len(y)) {
panic("blas: y index out of range")
}
if n == 0 {
return 0
}
C.cblas_zdotc_sub(C.int(n), unsafe.Pointer(&x[0]), C.int(incX), unsafe.Pointer(&y[0]), C.int(incY), unsafe.Pointer(&dotc))
return dotc
}
// Generated cases ...
`

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,57 @@
package cgo
import (
"testing"
"github.com/gonum/blas/testblas"
)
// impl is the Implementation value exercised by the tests in this package.
var impl Implementation
// TestDasum runs testblas.DasumTest against impl.
func TestDasum(t *testing.T) {
testblas.DasumTest(t, impl)
}
// TestDaxpy runs testblas.DaxpyTest against impl.
func TestDaxpy(t *testing.T) {
testblas.DaxpyTest(t, impl)
}
// TestDdot runs testblas.DdotTest against impl.
func TestDdot(t *testing.T) {
testblas.DdotTest(t, impl)
}
// TestDnrm2 runs testblas.Dnrm2Test against impl.
func TestDnrm2(t *testing.T) {
testblas.Dnrm2Test(t, impl)
}
// TestIdamax runs testblas.IdamaxTest against impl.
func TestIdamax(t *testing.T) {
testblas.IdamaxTest(t, impl)
}
// TestDswap runs testblas.DswapTest against impl.
func TestDswap(t *testing.T) {
testblas.DswapTest(t, impl)
}
// TestDcopy runs testblas.DcopyTest against impl.
func TestDcopy(t *testing.T) {
testblas.DcopyTest(t, impl)
}
// TestDrotg runs testblas.DrotgTest against impl.
func TestDrotg(t *testing.T) {
testblas.DrotgTest(t, impl)
}
// TestDrotmg runs testblas.DrotmgTest against impl.
func TestDrotmg(t *testing.T) {
testblas.DrotmgTest(t, impl)
}
// TestDrot runs testblas.DrotTest against impl.
func TestDrot(t *testing.T) {
testblas.DrotTest(t, impl)
}
// TestDrotm runs testblas.DrotmTest against impl.
func TestDrotm(t *testing.T) {
testblas.DrotmTest(t, impl)
}
// TestDscal runs testblas.DscalTest against impl.
func TestDscal(t *testing.T) {
testblas.DscalTest(t, impl)
}

View File

@@ -0,0 +1,75 @@
package cgo
import (
"testing"
"github.com/gonum/blas/testblas"
)
// TestDgemv runs testblas.DgemvTest against impl.
func TestDgemv(t *testing.T) {
testblas.DgemvTest(t, impl)
}
// TestDger runs testblas.DgerTest against impl.
func TestDger(t *testing.T) {
testblas.DgerTest(t, impl)
}
// TestDtbmv runs testblas.DtbmvTest against impl.
func TestDtbmv(t *testing.T) {
testblas.DtbmvTest(t, impl)
}
// TestDtxmv runs testblas.DtxmvTest against impl.
func TestDtxmv(t *testing.T) {
testblas.DtxmvTest(t, impl)
}
// TestDgbmv runs testblas.DgbmvTest against impl.
func TestDgbmv(t *testing.T) {
testblas.DgbmvTest(t, impl)
}
// TestDtbsv runs testblas.DtbsvTest against impl.
func TestDtbsv(t *testing.T) {
testblas.DtbsvTest(t, impl)
}
// TestDsbmv runs testblas.DsbmvTest against impl.
func TestDsbmv(t *testing.T) {
testblas.DsbmvTest(t, impl)
}
// TestDtrsv runs testblas.DtrsvTest against impl.
func TestDtrsv(t *testing.T) {
testblas.DtrsvTest(t, impl)
}
// TestDsyr runs testblas.DsyrTest against impl.
func TestDsyr(t *testing.T) {
testblas.DsyrTest(t, impl)
}
// TestDsymv runs testblas.DsymvTest against impl.
func TestDsymv(t *testing.T) {
testblas.DsymvTest(t, impl)
}
// TestDtrmv runs testblas.DtrmvTest against impl.
func TestDtrmv(t *testing.T) {
testblas.DtrmvTest(t, impl)
}
// TestDsyr2 runs testblas.Dsyr2Test against impl.
func TestDsyr2(t *testing.T) {
testblas.Dsyr2Test(t, impl)
}
// TestDspr2 runs testblas.Dspr2Test against impl.
func TestDspr2(t *testing.T) {
testblas.Dspr2Test(t, impl)
}
// TestDspr runs testblas.DsprTest against impl.
func TestDspr(t *testing.T) {
testblas.DsprTest(t, impl)
}
// TestDspmv runs testblas.DspmvTest against impl.
func TestDspmv(t *testing.T) {
testblas.DspmvTest(t, impl)
}
// TestDtpsv runs testblas.DtpsvTest against impl.
func TestDtpsv(t *testing.T) {
testblas.DtpsvTest(t, impl)
}
// TestDtmpv runs testblas.DtpmvTest against impl.
// NOTE(review): the function name looks like a transposition of "TestDtpmv";
// the body exercises Dtpmv. Confirm before renaming.
func TestDtmpv(t *testing.T) {
testblas.DtpmvTest(t, impl)
}

View File

@@ -0,0 +1,31 @@
package cgo
import (
"testing"
"github.com/gonum/blas/testblas"
)
// TestDgemm runs testblas.TestDgemm against impl.
func TestDgemm(t *testing.T) {
testblas.TestDgemm(t, impl)
}
// TestDsymm runs testblas.DsymmTest against impl.
func TestDsymm(t *testing.T) {
testblas.DsymmTest(t, impl)
}
// TestDtrsm runs testblas.DtrsmTest against impl.
func TestDtrsm(t *testing.T) {
testblas.DtrsmTest(t, impl)
}
// TestDsyrk runs testblas.DsyrkTest against impl.
func TestDsyrk(t *testing.T) {
testblas.DsyrkTest(t, impl)
}
// TestDsyr2k runs testblas.Dsyr2kTest against impl.
func TestDsyr2k(t *testing.T) {
testblas.Dsyr2kTest(t, impl)
}
// TestDtrmm runs testblas.DtrmmTest against impl.
func TestDtrmm(t *testing.T) {
testblas.DtrmmTest(t, impl)
}

22
blas/native/bench_test.go Normal file
View File

@@ -0,0 +1,22 @@
// Copyright ©2015 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package native
import (
"github.com/gonum/blas"
"github.com/gonum/blas/testblas"
)
// Short aliases for the testblas matrix-size constants, used to keep the
// benchmark call sites in this package compact.
const (
Sm = testblas.SmallMat
Med = testblas.MediumMat
Lg = testblas.LargeMat
Hg = testblas.HugeMat
)
// Short aliases for the blas transpose options passed to the benchmarks.
const (
T = blas.Trans
NT = blas.NoTrans
)

276
blas/native/dgemm.go Normal file
View File

@@ -0,0 +1,276 @@
// Copyright ©2014 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package native
import (
"runtime"
"sync"
"github.com/gonum/blas"
"github.com/gonum/internal/asm/f64"
)
// Dgemm computes
//  C = beta * C + alpha * A * B,
// where A, B, and C are dense matrices, and alpha and beta are scalars.
// tA and tB specify whether A or B are transposed.
//
// Dgemm panics with badTranspose if tA or tB is not one of blas.NoTrans,
// blas.Trans or blas.ConjTrans, and checkMatrix64 panics if any of the
// matrix arguments is inconsistent with its dimensions and stride.
func (Implementation) Dgemm(tA, tB blas.Transpose, m, n, k int, alpha float64, a []float64, lda int, b []float64, ldb int, beta float64, c []float64, ldc int) {
	if !(tA == blas.NoTrans || tA == blas.Trans || tA == blas.ConjTrans) {
		panic(badTranspose)
	}
	if !(tB == blas.NoTrans || tB == blas.Trans || tB == blas.ConjTrans) {
		panic(badTranspose)
	}

	// For real matrices a conjugate transpose is treated as a plain transpose.
	aTrans := tA == blas.Trans || tA == blas.ConjTrans
	bTrans := tB == blas.Trans || tB == blas.ConjTrans

	// A is stored as k×m when transposed and m×k otherwise.
	aRows, aCols := m, k
	if aTrans {
		aRows, aCols = k, m
	}
	checkMatrix64(aRows, aCols, a, lda)

	// B is stored as n×k when transposed and k×n otherwise.
	bRows, bCols := k, n
	if bTrans {
		bRows, bCols = n, k
	}
	checkMatrix64(bRows, bCols, b, ldb)

	checkMatrix64(m, n, c, ldc)

	// Apply beta to C first; the multiplication below only accumulates into C.
	switch beta {
	case 1:
		// C is used as is.
	case 0:
		for i := 0; i < m; i++ {
			row := c[i*ldc : i*ldc+n]
			for j := range row {
				row[j] = 0
			}
		}
	default:
		for i := 0; i < m; i++ {
			row := c[i*ldc : i*ldc+n]
			for j := range row {
				row[j] *= beta
			}
		}
	}

	dgemmParallel(aTrans, bTrans, m, n, k, a, lda, b, ldb, c, ldc, alpha)
}
// dgemmParallel accumulates alpha * op(A) * op(B) into c, where op transposes
// its argument when the corresponding aTrans/bTrans flag is set. It splits c
// into blocks that are handed to worker goroutines over a channel, and falls
// back to dgemmSerial when there are fewer than minParBlock blocks of c.
func dgemmParallel(aTrans, bTrans bool, m, n, k int, a []float64, lda int, b []float64, ldb int, c []float64, ldc int, alpha float64) {
// dgemmParallel computes a parallel matrix multiplication by partitioning
// a and b into sub-blocks, and updating c with the multiplication of the sub-block
// In all cases,
// A = [ A_11 A_12 ... A_1j
// A_21 A_22 ... A_2j
// ...
// A_i1 A_i2 ... A_ij]
//
// and same for B. All of the submatrix sizes are blockSize×blockSize except
// at the edges.
//
// In all cases, there is one dimension for each matrix along which
// C must be updated sequentially.
// Cij = \sum_k Aik Bkj, (A * B)
// Cij = \sum_k Aki Bkj, (A^T * B)
// Cij = \sum_k Aik Bjk, (A * B^T)
// Cij = \sum_k Aki Bjk, (A^T * B^T)
//
// This code computes one {i, j} block sequentially along the k dimension,
// and computes all of the {i, j} blocks concurrently. This
// partitioning allows Cij to be updated in-place without race-conditions.
// Instead of launching a goroutine for each possible concurrent computation,
// a number of worker goroutines are created and channels are used to pass
// available and completed cases.
//
// http://alexkr.com/docs/matrixmult.pdf is a good reference on matrix-matrix
// multiplies, though this code does not copy matrices to attempt to eliminate
// cache misses.
// Keep the full inner dimension under a separate name; k is shadowed by the
// block offset inside the worker loop below.
maxKLen := k
parBlocks := blocks(m, blockSize) * blocks(n, blockSize)
if parBlocks < minParBlock {
// The matrix multiplication is small in the dimensions where it can be
// computed concurrently. Just do it in serial.
dgemmSerial(aTrans, bTrans, m, n, k, a, lda, b, ldb, c, ldc, alpha)
return
}
// Never start more workers than there are blocks of c to compute.
nWorkers := runtime.GOMAXPROCS(0)
if parBlocks < nWorkers {
nWorkers = parBlocks
}
// There is a tradeoff between the workers having to wait for work
// and a large buffer making operations slow.
buf := buffMul * nWorkers
if buf > parBlocks {
buf = parBlocks
}
sendChan := make(chan subMul, buf)
// Launch workers. A worker receives an {i, j} submatrix of c, and computes
// A_ik B_ki (or the transposed version) storing the result in c_ij. When the
// channel is finally closed, it signals to the waitgroup that it has finished
// computing.
var wg sync.WaitGroup
for i := 0; i < nWorkers; i++ {
wg.Add(1)
go func() {
defer wg.Done()
// Make local copies of otherwise global variables to reduce shared memory.
// This has a noticeable effect on benchmarks in some cases.
alpha := alpha
aTrans := aTrans
bTrans := bTrans
m := m
n := n
for sub := range sendChan {
// i and j are the row and column offsets of this block within c.
i := sub.i
j := sub.j
// Blocks on the bottom and right edges may be smaller than blockSize.
leni := blockSize
if i+leni > m {
leni = m - i
}
lenj := blockSize
if j+lenj > n {
lenj = n - j
}
cSub := sliceView64(c, ldc, i, j, leni, lenj)
// Compute A_ik B_kj for all k
for k := 0; k < maxKLen; k += blockSize {
lenk := blockSize
if k+lenk > maxKLen {
lenk = maxKLen - k
}
// The sub-block of a transposed operand is taken with its row and
// column offsets swapped, matching how that operand is stored.
var aSub, bSub []float64
if aTrans {
aSub = sliceView64(a, lda, k, i, lenk, leni)
} else {
aSub = sliceView64(a, lda, i, k, leni, lenk)
}
if bTrans {
bSub = sliceView64(b, ldb, j, k, lenj, lenk)
} else {
bSub = sliceView64(b, ldb, k, j, lenk, lenj)
}
dgemmSerial(aTrans, bTrans, leni, lenj, lenk, aSub, lda, bSub, ldb, cSub, ldc, alpha)
}
}
}()
}
// Send out all of the {i, j} subblocks for computation.
for i := 0; i < m; i += blockSize {
for j := 0; j < n; j += blockSize {
sendChan <- subMul{
i: i,
j: j,
}
}
}
// Closing the channel ends each worker's range loop; wait for all of them
// to finish before returning.
close(sendChan)
wg.Wait()
}
// dgemmSerial is serial matrix multiply
func dgemmSerial(aTrans, bTrans bool, m, n, k int, a []float64, lda int, b []float64, ldb int, c []float64, ldc int, alpha float64) {
switch {
case !aTrans && !bTrans:
dgemmSerialNotNot(m, n, k, a, lda, b, ldb, c, ldc, alpha)
return
case aTrans && !bTrans:
dgemmSerialTransNot(m, n, k, a, lda, b, ldb, c, ldc, alpha)
return
case !aTrans && bTrans:
dgemmSerialNotTrans(m, n, k, a, lda, b, ldb, c, ldc, alpha)
return
case aTrans && bTrans:
dgemmSerialTransTrans(m, n, k, a, lda, b, ldb, c, ldc, alpha)
return
default:
panic("unreachable")
}
}
// dgemmSerialNotNot is the dgemmSerial kernel for the case where neither a
// nor b is transposed. For each row i of c and each l in [0, k) it updates
// row i of c using alpha*a[i][l] and row l of b.
func dgemmSerialNotNot(m, n, k int, a []float64, lda int, b []float64, ldb int, c []float64, ldc int, alpha float64) {
// Row slices are used instead of literal [i*stride+j] indexing because this
// was approximately 5 times faster as of go 1.3.
for i := 0; i < m; i++ {
ctmp := c[i*ldc : i*ldc+n]
for l, v := range a[i*lda : i*lda+k] {
tmp := alpha * v
// A zero coefficient contributes nothing, so the row update is skipped.
if tmp != 0 {
// Presumably computes ctmp = tmp*b[l,:] + ctmp — confirm against f64.AxpyUnitaryTo.
f64.AxpyUnitaryTo(ctmp, tmp, b[l*ldb:l*ldb+n], ctmp)
}
}
}
}
// dgemmSerialTransNot is the dgemmSerial kernel for the case where a is
// transposed and b is not. a is read as k rows of m elements, so a[l*lda+i]
// supplies the coefficient applied to row l of b when updating row i of c.
func dgemmSerialTransNot(m, n, k int, a []float64, lda int, b []float64, ldb int, c []float64, ldc int, alpha float64) {
// Row slices are used instead of literal [i*stride+j] indexing because this
// was approximately 5 times faster as of go 1.3.
for l := 0; l < k; l++ {
btmp := b[l*ldb : l*ldb+n]
for i, v := range a[l*lda : l*lda+m] {
tmp := alpha * v
// A zero coefficient contributes nothing, so the row update is skipped.
if tmp != 0 {
ctmp := c[i*ldc : i*ldc+n]
// Presumably computes ctmp = tmp*btmp + ctmp — confirm against f64.AxpyUnitaryTo.
f64.AxpyUnitaryTo(ctmp, tmp, btmp, ctmp)
}
}
}
}
// dgemmSerialNotTrans is the dgemmSerial kernel for the case where a is not
// transposed and b is. b is read as n rows of k elements, so element (i, j)
// of c is incremented by alpha times the result of f64.DotUnitary on row i
// of a and row j of b.
func dgemmSerialNotTrans(m, n, k int, a []float64, lda int, b []float64, ldb int, c []float64, ldc int, alpha float64) {
// Row slices are used instead of literal [i*stride+j] indexing because this
// was approximately 5 times faster as of go 1.3.
for i := 0; i < m; i++ {
atmp := a[i*lda : i*lda+k]
ctmp := c[i*ldc : i*ldc+n]
for j := 0; j < n; j++ {
ctmp[j] += alpha * f64.DotUnitary(atmp, b[j*ldb:j*ldb+k])
}
}
}
// dgemmSerialTransTrans is the dgemmSerial kernel for the case where both a
// and b are transposed. a is read as k rows of m elements, and column l of b
// is walked with stride ldb starting at b[l].
func dgemmSerialTransTrans(m, n, k int, a []float64, lda int, b []float64, ldb int, c []float64, ldc int, alpha float64) {
// Row slices are used instead of literal [i*stride+j] indexing because this
// was approximately 5 times faster as of go 1.3.
for l := 0; l < k; l++ {
for i, v := range a[l*lda : l*lda+m] {
tmp := alpha * v
// A zero coefficient contributes nothing, so the row update is skipped.
if tmp != 0 {
ctmp := c[i*ldc : i*ldc+n]
// Presumably adds tmp times n elements of b[l:], taken with increment
// ldb, to the contiguous elements of ctmp — confirm the argument order
// against f64.AxpyInc.
f64.AxpyInc(tmp, b[l:], ctmp, uintptr(n), uintptr(ldb), 1, 0, 0)
}
}
}
}
// sliceView64 returns the sub-slice of a that spans the r×c block whose
// top-left element is at row i, column j, where a holds a row-major matrix
// with stride lda. The result shares memory with a and keeps the stride lda;
// it starts at element (i, j) and ends just past element (i+r-1, j+c-1).
func sliceView64(a []float64, lda, i, j, r, c int) []float64 {
	first := i*lda + j
	end := (i+r-1)*lda + j + c
	return a[first:end]
}
// checkMatrix64 validates that a can hold an m×n row-major matrix with
// stride lda. It panics, in this order of precedence, if m is negative, if n
// is negative, if lda is smaller than n, or if a is shorter than
// (m-1)*lda+n elements.
func checkMatrix64(m, n int, a []float64, lda int) {
	switch {
	case m < 0:
		panic("blas: rows < 0")
	case n < 0:
		panic("blas: cols < 0")
	case lda < n:
		panic("blas: illegal stride")
	case len(a) < (m-1)*lda+n:
		panic("blas: insufficient matrix slice length")
	}
}

View File

@@ -0,0 +1,51 @@
// Copyright ©2014 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package native
import (
"testing"
"github.com/gonum/blas/testblas"
)
// The benchmarks below each call testblas.DgemmBenchmark on impl. The name
// suffix spells out the three size arguments in call order (Sm, Med, Lg, Hg),
// followed by the two transpose arguments when they are not both NT
// (TNT = T, NT; NTT = NT, T; TT = T, T).
func BenchmarkDgemmSmSmSm(b *testing.B) {
testblas.DgemmBenchmark(b, impl, Sm, Sm, Sm, NT, NT)
}
func BenchmarkDgemmMedMedMed(b *testing.B) {
testblas.DgemmBenchmark(b, impl, Med, Med, Med, NT, NT)
}
func BenchmarkDgemmMedLgMed(b *testing.B) {
testblas.DgemmBenchmark(b, impl, Med, Lg, Med, NT, NT)
}
func BenchmarkDgemmLgLgLg(b *testing.B) {
testblas.DgemmBenchmark(b, impl, Lg, Lg, Lg, NT, NT)
}
func BenchmarkDgemmLgSmLg(b *testing.B) {
testblas.DgemmBenchmark(b, impl, Lg, Sm, Lg, NT, NT)
}
func BenchmarkDgemmLgLgSm(b *testing.B) {
testblas.DgemmBenchmark(b, impl, Lg, Lg, Sm, NT, NT)
}
func BenchmarkDgemmHgHgSm(b *testing.B) {
testblas.DgemmBenchmark(b, impl, Hg, Hg, Sm, NT, NT)
}
func BenchmarkDgemmMedMedMedTNT(b *testing.B) {
testblas.DgemmBenchmark(b, impl, Med, Med, Med, T, NT)
}
func BenchmarkDgemmMedMedMedNTT(b *testing.B) {
testblas.DgemmBenchmark(b, impl, Med, Med, Med, NT, T)
}
func BenchmarkDgemmMedMedMedTT(b *testing.B) {
testblas.DgemmBenchmark(b, impl, Med, Med, Med, T, T)
}

View File

@@ -0,0 +1,91 @@
// Copyright ©2015 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package native
import (
"testing"
"github.com/gonum/blas/testblas"
)
// The benchmarks below each call testblas.DgemvBenchmark on impl. The name
// spells out the two size arguments in call order (Sm, Med, Lg), then the
// transpose argument (NoTrans = NT, Trans = T), then the final two arguments:
// Inc1 variants pass 1 and 1, IncN variants pass 2 and 3.
// NOTE(review): the final two arguments are presumably the x and y vector
// increments — confirm against testblas.DgemvBenchmark.
func BenchmarkDgemvSmSmNoTransInc1(b *testing.B) {
testblas.DgemvBenchmark(b, impl, NT, Sm, Sm, 1, 1)
}
func BenchmarkDgemvSmSmNoTransIncN(b *testing.B) {
testblas.DgemvBenchmark(b, impl, NT, Sm, Sm, 2, 3)
}
func BenchmarkDgemvSmSmTransInc1(b *testing.B) {
testblas.DgemvBenchmark(b, impl, T, Sm, Sm, 1, 1)
}
func BenchmarkDgemvSmSmTransIncN(b *testing.B) {
testblas.DgemvBenchmark(b, impl, T, Sm, Sm, 2, 3)
}
func BenchmarkDgemvMedMedNoTransInc1(b *testing.B) {
testblas.DgemvBenchmark(b, impl, NT, Med, Med, 1, 1)
}
func BenchmarkDgemvMedMedNoTransIncN(b *testing.B) {
testblas.DgemvBenchmark(b, impl, NT, Med, Med, 2, 3)
}
func BenchmarkDgemvMedMedTransInc1(b *testing.B) {
testblas.DgemvBenchmark(b, impl, T, Med, Med, 1, 1)
}
func BenchmarkDgemvMedMedTransIncN(b *testing.B) {
testblas.DgemvBenchmark(b, impl, T, Med, Med, 2, 3)
}
func BenchmarkDgemvLgLgNoTransInc1(b *testing.B) {
testblas.DgemvBenchmark(b, impl, NT, Lg, Lg, 1, 1)
}
func BenchmarkDgemvLgLgNoTransIncN(b *testing.B) {
testblas.DgemvBenchmark(b, impl, NT, Lg, Lg, 2, 3)
}
func BenchmarkDgemvLgLgTransInc1(b *testing.B) {
testblas.DgemvBenchmark(b, impl, T, Lg, Lg, 1, 1)
}
func BenchmarkDgemvLgLgTransIncN(b *testing.B) {
testblas.DgemvBenchmark(b, impl, T, Lg, Lg, 2, 3)
}
func BenchmarkDgemvLgSmNoTransInc1(b *testing.B) {
testblas.DgemvBenchmark(b, impl, NT, Lg, Sm, 1, 1)
}
func BenchmarkDgemvLgSmNoTransIncN(b *testing.B) {
testblas.DgemvBenchmark(b, impl, NT, Lg, Sm, 2, 3)
}
func BenchmarkDgemvLgSmTransInc1(b *testing.B) {
testblas.DgemvBenchmark(b, impl, T, Lg, Sm, 1, 1)
}
func BenchmarkDgemvLgSmTransIncN(b *testing.B) {
testblas.DgemvBenchmark(b, impl, T, Lg, Sm, 2, 3)
}
func BenchmarkDgemvSmLgNoTransInc1(b *testing.B) {
testblas.DgemvBenchmark(b, impl, NT, Sm, Lg, 1, 1)
}
func BenchmarkDgemvSmLgNoTransIncN(b *testing.B) {
testblas.DgemvBenchmark(b, impl, NT, Sm, Lg, 2, 3)
}
func BenchmarkDgemvSmLgTransInc1(b *testing.B) {
testblas.DgemvBenchmark(b, impl, T, Sm, Lg, 1, 1)
}
func BenchmarkDgemvSmLgTransIncN(b *testing.B) {
testblas.DgemvBenchmark(b, impl, T, Sm, Lg, 2, 3)
}

View File

@@ -0,0 +1,51 @@
// Copyright ©2015 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package native
import (
"testing"
"github.com/gonum/blas/testblas"
)
// The benchmarks below each call testblas.DgerBenchmark on impl. The name
// spells out the two size arguments in call order (Sm, Med, Lg), then the
// final two arguments: Inc1 variants pass 1 and 1, IncN variants pass 2 and 3.
// NOTE(review): the final two arguments are presumably the x and y vector
// increments — confirm against testblas.DgerBenchmark.
func BenchmarkDgerSmSmInc1(b *testing.B) {
testblas.DgerBenchmark(b, impl, Sm, Sm, 1, 1)
}
func BenchmarkDgerSmSmIncN(b *testing.B) {
testblas.DgerBenchmark(b, impl, Sm, Sm, 2, 3)
}
func BenchmarkDgerMedMedInc1(b *testing.B) {
testblas.DgerBenchmark(b, impl, Med, Med, 1, 1)
}
func BenchmarkDgerMedMedIncN(b *testing.B) {
testblas.DgerBenchmark(b, impl, Med, Med, 2, 3)
}
func BenchmarkDgerLgLgInc1(b *testing.B) {
testblas.DgerBenchmark(b, impl, Lg, Lg, 1, 1)
}
func BenchmarkDgerLgLgIncN(b *testing.B) {
testblas.DgerBenchmark(b, impl, Lg, Lg, 2, 3)
}
func BenchmarkDgerLgSmInc1(b *testing.B) {
testblas.DgerBenchmark(b, impl, Lg, Sm, 1, 1)
}
func BenchmarkDgerLgSmIncN(b *testing.B) {
testblas.DgerBenchmark(b, impl, Lg, Sm, 2, 3)
}
func BenchmarkDgerSmLgInc1(b *testing.B) {
testblas.DgerBenchmark(b, impl, Sm, Lg, 1, 1)
}
func BenchmarkDgerSmLgIncN(b *testing.B) {
testblas.DgerBenchmark(b, impl, Sm, Lg, 2, 3)
}

88
blas/native/doc.go Normal file
View File

@@ -0,0 +1,88 @@
// Copyright ©2015 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Ensure changes made to blas/native are reflected in blas/cgo where relevant.
/*
Package native is a Go implementation of the BLAS API. This implementation
panics when the input arguments are invalid as per the standard, for example
if a vector increment is zero. Please note that the treatment of NaN values
is not specified, and differs among the BLAS implementations.
github.com/gonum/blas/blas64 provides helpful wrapper functions to the BLAS
interface. The rest of this text describes the layout of the data for the input types.
Please note that in the function documentation, x[i] refers to the i^th element
of the vector, which will be different from the i^th element of the slice if
incX != 1.
See http://www.netlib.org/lapack/explore-html/d4/de1/_l_i_c_e_n_s_e_source.html
for more license information.
Vector arguments are effectively strided slices. They have two input arguments,
a number of elements, n, and an increment, incX. The increment specifies the
distance between elements of the vector. The actual Go slice may be longer
than necessary.
The increment may be positive or negative, except in functions with only
a single vector argument where the increment may only be positive. If the increment
is negative, s[0] is the last element in the slice. Note that this is not the same
as counting backward from the end of the slice, as len(s) may be longer than
necessary. So, for example, if n = 5 and incX = 3, the elements of s are
[0 * * 1 * * 2 * * 3 * * 4 * * * ...]
where * elements are never accessed. If incX = -3, the same elements are
accessed, just in reverse order (4, 3, 2, 1, 0).
Dense matrices are specified by a number of rows, a number of columns, and a stride.
The stride specifies the number of entries in the slice between the first element
of successive rows. The stride must be at least as large as the number of columns
but may be longer.
[a00 ... a0n a0* ... a0stride-1 a10 ... amn am* ... amstride-1]
Thus, dense[i*ld + j] refers to the {i, j}th element of the matrix.
Symmetric and triangular matrices (non-packed) are stored identically to Dense,
except that only elements in one triangle of the matrix are accessed.
Packed symmetric and packed triangular matrices are laid out with the entries
condensed such that all of the unreferenced elements are removed. So, the upper triangular
matrix
[
1 2 3
0 4 5
0 0 6
]
and the lower-triangular matrix
[
1 0 0
2 3 0
4 5 6
]
will both be compacted as [1 2 3 4 5 6]. The (i, j) element of the original
dense matrix can be found at element i*n - (i+1)*i/2 + j for upper triangular,
and at element i * (i+1) /2 + j for lower triangular.
Banded matrices are laid out in a compact format, constructed by removing the
zeros in the rows and aligning the diagonals. For example, the matrix
[
1 2 3 0 0 0
4 5 6 7 0 0
0 8 9 10 11 0
0 0 12 13 14 15
0 0 0 16 17 18
0 0 0 0 19 20
]
implicitly becomes (* entries are never accessed)
[
* 1 2 3
4 5 6 7
8 9 10 11
12 13 14 15
16 17 18 *
19 20 * *
]
which is given to the BLAS routine as [* 1 2 3 4 ...].
See http://www.crest.iu.edu/research/mtl/reference/html/banded.html
for more information.
*/
package native

View File

@@ -0,0 +1,54 @@
// Copyright ©2017 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build go1.7
package native
import (
"strconv"
"testing"
"github.com/gonum/blas"
"github.com/gonum/blas/testblas"
)
// BenchmarkDtrmv runs testblas.DtrmvBenchmark as a sub-benchmark for every
// combination of matrix size (medium, large), vector increment (1, 5),
// triangle (upper, lower), transpose option (none, transposed) and diagonal
// kind (non-unit, unit). Each sub-benchmark is named after its parameters,
// for example "Med_Inc1_UP_NT_NU".
func BenchmarkDtrmv(b *testing.B) {
	// sizeTag names the matrix size; any size other than the two benchmarked
	// ones yields an empty tag.
	sizeTag := func(n int) string {
		if n == testblas.MediumMat {
			return "Med"
		}
		if n == testblas.LargeMat {
			return "Large"
		}
		return ""
	}
	// pick returns yes when cond holds and no otherwise.
	pick := func(cond bool, yes, no string) string {
		if cond {
			return yes
		}
		return no
	}

	sizes := []int{testblas.MediumMat, testblas.LargeMat}
	increments := []int{1, 5}
	triangles := []blas.Uplo{blas.Upper, blas.Lower}
	transposes := []blas.Transpose{blas.NoTrans, blas.Trans}
	diagonals := []blas.Diag{blas.NonUnit, blas.Unit}

	for _, n := range sizes {
		for _, incX := range increments {
			for _, uplo := range triangles {
				for _, trans := range transposes {
					for _, unit := range diagonals {
						name := sizeTag(n) +
							"_Inc" + strconv.Itoa(incX) +
							pick(uplo == blas.Upper, "_UP", "_LO") +
							pick(trans == blas.NoTrans, "_NT", "_TR") +
							pick(unit == blas.NonUnit, "_NU", "_UN")
						lda := n
						b.Run(name, func(b *testing.B) {
							testblas.DtrmvBenchmark(b, Implementation{}, n, lda, incX, uplo, trans, unit)
						})
					}
				}
			}
		}
	}
}

View File

@@ -0,0 +1,155 @@
// Copyright ©2014 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package native
import (
"errors"
"fmt"
"math"
)
// general64 is a dense row-major float64 matrix. Element (i, j) is stored at
// data[i*stride+j].
type general64 struct {
	data       []float64
	rows, cols int
	stride     int
}

// newGeneral64 returns a zero-filled r×c matrix whose stride equals its
// number of columns.
func newGeneral64(r, c int) general64 {
	var g general64
	g.data = make([]float64, r*c)
	g.rows, g.cols = r, c
	g.stride = c
	return g
}
// add adds h element-wise into the receiver's data. The dimensions of g and
// h must match; this is only verified when debug is enabled.
func (g general64) add(h general64) {
	if debug {
		switch {
		case g.rows != h.rows:
			panic("blas: row size mismatch")
		case g.cols != h.cols:
			panic("blas: col size mismatch")
		}
	}
	for i := 0; i < g.rows; i++ {
		dst := g.data[i*g.stride : i*g.stride+g.cols]
		src := h.data[i*h.stride : i*h.stride+h.cols]
		for j := range src {
			dst[j] += src[j]
		}
	}
}
// at returns the element in row i, column j. The indices are validated
// against the matrix dimensions only when debug is enabled; otherwise, for
// speed, only the slice bounds check on the underlying data applies.
func (g general64) at(i, j int) float64 {
	if debug {
		rowOK := 0 <= i && i < g.rows
		if !rowOK {
			panic("blas: row out of bounds")
		}
		colOK := 0 <= j && j < g.cols
		if !colOK {
			panic("blas: col out of bounds")
		}
	}
	idx := i*g.stride + j
	return g.data[idx]
}
// check reports whether the matrix description is self-consistent. It returns
// an error for the first violated condition: negative rows, negative columns,
// a stride below 1, a stride smaller than the column count, or a data slice
// too short for the dimensions. c is the matrix name used in the final
// error message.
func (g general64) check(c byte) error {
	switch {
	case g.rows < 0:
		return errors.New("blas: rows < 0")
	case g.cols < 0:
		return errors.New("blas: cols < 0")
	case g.stride < 1:
		return errors.New("blas: stride < 1")
	case g.stride < g.cols:
		return errors.New("blas: illegal stride")
	case (g.rows-1)*g.stride+g.cols > len(g.data):
		return fmt.Errorf("blas: index of %c out of range", c)
	}
	return nil
}
// clone returns a copy of g with the same dimensions and stride whose data
// is held in a newly allocated slice of the same length.
func (g general64) clone() general64 {
	dup := g
	dup.data = make([]float64, len(g.data))
	copy(dup.data, g.data)
	return dup
}
// copy copies the elements of h into the receiver's data, row by row. The
// two matrices are assumed to have the same dimensions; this is only
// verified when debug is enabled. Their strides may differ.
//
// Only the cols elements that belong to each row are copied. Copying a full
// stride per row would overwrite elements outside the matrix when
// stride > cols (for example on a matrix produced by view), and could slice
// past the end of the backing array on the last row.
func (g general64) copy(h general64) {
	if debug {
		if g.rows != h.rows {
			panic("blas: row mismatch")
		}
		if g.cols != h.cols {
			panic("blas: col mismatch")
		}
	}
	for k := 0; k < g.rows; k++ {
		dst := g.data[k*g.stride : k*g.stride+g.cols]
		src := h.data[k*h.stride : k*h.stride+h.cols]
		copy(dst, src)
	}
}
// equal reports whether g and a have identical dimensions and stride and
// every entry of g.data equals the entry of a.data at the same index.
func (g general64) equal(a general64) bool {
	sameShape := g.rows == a.rows && g.cols == a.cols && g.stride == a.stride
	if !sameShape {
		return false
	}
	for i := range g.data {
		if g.data[i] != a.data[i] {
			return false
		}
	}
	return true
}
/*
// print is to aid debugging. Commented out to avoid fmt import
func (g general64) print() {
fmt.Println("r = ", g.rows, "c = ", g.cols, "stride: ", g.stride)
for i := 0; i < g.rows; i++ {
fmt.Println(g.data[i*g.stride : (i+1)*g.stride])
}
}
*/
// view returns the r×c submatrix of g whose top-left element is (i, j). The
// result shares its data with g and keeps g's stride. The requested window
// is validated against g's dimensions only when debug is enabled.
func (g general64) view(i, j, r, c int) general64 {
	if debug {
		switch {
		case i < 0 || i+r > g.rows:
			panic("blas: row out of bounds")
		case j < 0 || j+c > g.cols:
			panic("blas: col out of bounds")
		}
	}
	// The window runs from element (i, j) to just past element (i+r-1, j+c-1).
	lo := i*g.stride + j
	hi := (i+r-1)*g.stride + j + c
	sub := g
	sub.data = g.data[lo:hi]
	sub.rows, sub.cols = r, c
	return sub
}
// equalWithinAbs reports whether g and a have identical dimensions and
// stride and no entry of g.data differs from the entry of a.data at the same
// index by more than tol in absolute value.
func (g general64) equalWithinAbs(a general64, tol float64) bool {
	sameShape := g.rows == a.rows && g.cols == a.cols && g.stride == a.stride
	if !sameShape {
		return false
	}
	for i := range g.data {
		diff := math.Abs(a.data[i] - g.data[i])
		if diff > tol {
			return false
		}
	}
	return true
}

View File

@@ -0,0 +1,157 @@
// Code generated by "go generate github.com/gonum/blas/native"; DO NOT EDIT.
// Copyright ©2014 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package native
import (
"errors"
"fmt"
math "github.com/gonum/blas/native/internal/math32"
)
func newGeneral32(r, c int) general32 {
return general32{
data: make([]float32, r*c),
rows: r,
cols: c,
stride: c,
}
}
type general32 struct {
data []float32
rows, cols int
stride int
}
// adds element-wise into receiver. rows and columns must match
func (g general32) add(h general32) {
if debug {
if g.rows != h.rows {
panic("blas: row size mismatch")
}
if g.cols != h.cols {
panic("blas: col size mismatch")
}
}
for i := 0; i < g.rows; i++ {
gtmp := g.data[i*g.stride : i*g.stride+g.cols]
for j, v := range h.data[i*h.stride : i*h.stride+h.cols] {
gtmp[j] += v
}
}
}
// at returns the value at the ith row and jth column. For speed reasons, the
// rows and columns are not bounds checked.
func (g general32) at(i, j int) float32 {
if debug {
if i < 0 || i >= g.rows {
panic("blas: row out of bounds")
}
if j < 0 || j >= g.cols {
panic("blas: col out of bounds")
}
}
return g.data[i*g.stride+j]
}
func (g general32) check(c byte) error {
if g.rows < 0 {
return errors.New("blas: rows < 0")
}
if g.cols < 0 {
return errors.New("blas: cols < 0")
}
if g.stride < 1 {
return errors.New("blas: stride < 1")
}
if g.stride < g.cols {
return errors.New("blas: illegal stride")
}
if (g.rows-1)*g.stride+g.cols > len(g.data) {
return fmt.Errorf("blas: index of %c out of range", c)
}
return nil
}
func (g general32) clone() general32 {
data := make([]float32, len(g.data))
copy(data, g.data)
return general32{
data: data,
rows: g.rows,
cols: g.cols,
stride: g.stride,
}
}
// assumes they are the same size
func (g general32) copy(h general32) {
if debug {
if g.rows != h.rows {
panic("blas: row mismatch")
}
if g.cols != h.cols {
panic("blas: col mismatch")
}
}
for k := 0; k < g.rows; k++ {
copy(g.data[k*g.stride:(k+1)*g.stride], h.data[k*h.stride:(k+1)*h.stride])
}
}
func (g general32) equal(a general32) bool {
if g.rows != a.rows || g.cols != a.cols || g.stride != a.stride {
return false
}
for i, v := range g.data {
if a.data[i] != v {
return false
}
}
return true
}
/*
// print is to aid debugging. Commented out to avoid fmt import
func (g general32) print() {
fmt.Println("r = ", g.rows, "c = ", g.cols, "stride: ", g.stride)
for i := 0; i < g.rows; i++ {
fmt.Println(g.data[i*g.stride : (i+1)*g.stride])
}
}
*/
func (g general32) view(i, j, r, c int) general32 {
if debug {
if i < 0 || i+r > g.rows {
panic("blas: row out of bounds")
}
if j < 0 || j+c > g.cols {
panic("blas: col out of bounds")
}
}
return general32{
data: g.data[i*g.stride+j : (i+r-1)*g.stride+j+c],
rows: r,
cols: c,
stride: g.stride,
}
}
func (g general32) equalWithinAbs(a general32, tol float32) bool {
if g.rows != a.rows || g.cols != a.cols || g.stride != a.stride {
return false
}
for i, v := range g.data {
if math.Abs(a.data[i]-v) > tol {
return false
}
}
return true
}

View File

@@ -0,0 +1,113 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Copyright ©2015 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package math32 provides float32 versions of standard library math package
// routines used by gonum/blas/native.
package math32
import (
"math"
)
// Bit-level constants for float32 values.
const (
unan = 0x7fc00000 // bit pattern returned by NaN
uinf = 0x7f800000 // bit pattern returned by Inf for sign >= 0
uneginf = 0xff800000 // bit pattern returned by Inf for sign < 0
mask = 0x7f8 >> 3 // evaluates to 0xff; presumably the exponent-field mask — confirm
shift = 32 - 8 - 1 // evaluates to 23; presumably the exponent-field shift — confirm
bias = 127 // presumably the exponent bias — confirm; not referenced in the code visible here
)
// Abs returns the absolute value of x.
//
// Special cases are:
//	Abs(±Inf) = +Inf
//	Abs(NaN) = NaN
func Abs(x float32) float32 {
	if x < 0 {
		return -x
	}
	if x == 0 {
		// Negative zero compares equal to zero; return a positive zero.
		return 0
	}
	// Positive values and NaN are returned unchanged.
	return x
}
// Copysign returns a value with the magnitude
// of x and the sign of y.
func Copysign(x, y float32) float32 {
	const signBit = 1 << 31
	magnitude := math.Float32bits(x) &^ signBit
	sign := math.Float32bits(y) & signBit
	return math.Float32frombits(magnitude | sign)
}
// Hypot returns Sqrt(p*p + q*q), taking care to avoid
// unnecessary overflow and underflow.
//
// Special cases are:
//	Hypot(±Inf, q) = +Inf
//	Hypot(p, ±Inf) = +Inf
//	Hypot(NaN, q) = NaN
//	Hypot(p, NaN) = NaN
func Hypot(p, q float32) float32 {
	// An infinite argument takes precedence over a NaN one.
	if IsInf(p, 0) || IsInf(q, 0) {
		return Inf(1)
	}
	if IsNaN(p) || IsNaN(q) {
		return NaN()
	}

	// Work with magnitudes, keeping the larger one in hi.
	hi, lo := p, q
	if hi < 0 {
		hi = -hi
	}
	if lo < 0 {
		lo = -lo
	}
	if hi < lo {
		hi, lo = lo, hi
	}
	if hi == 0 {
		return 0
	}

	// Scaling by the larger magnitude keeps the squared term at most 1.
	ratio := lo / hi
	return hi * Sqrt(1+ratio*ratio)
}
// Inf returns positive infinity if sign >= 0, negative infinity if sign < 0.
func Inf(sign int) float32 {
	bits := uint32(uinf)
	if sign < 0 {
		bits = uneginf
	}
	return math.Float32frombits(bits)
}
// IsInf reports whether f is an infinity, according to sign.
// If sign > 0, IsInf reports whether f is positive infinity.
// If sign < 0, IsInf reports whether f is negative infinity.
// If sign == 0, IsInf reports whether f is either infinity.
func IsInf(f float32, sign int) bool {
	// Infinity is detected by comparing against the largest finite float32.
	// To avoid the floating-point hardware, could use:
	//	x := math.Float32bits(f);
	//	return sign >= 0 && x == uinf || sign <= 0 && x == uneginf;
	if sign >= 0 && f > math.MaxFloat32 {
		return true
	}
	return sign <= 0 && f < -math.MaxFloat32
}
// IsNaN reports whether f is an IEEE 754 ``not-a-number'' value.
func IsNaN(f float32) (is bool) {
	// IEEE 754 says that only NaNs fail to compare equal to themselves.
	// To avoid the floating-point hardware, could use:
	//	x := math.Float32bits(f);
	//	return uint32(x>>shift)&mask == mask && x != uinf && x != uneginf
	is = !(f == f)
	return is
}
// NaN returns an IEEE 754 ``not-a-number'' value.
func NaN() float32 {
	const bits uint32 = unan
	return math.Float32frombits(bits)
}

View File

@@ -0,0 +1,226 @@
// Copyright ©2015 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package math32
import (
"math"
"testing"
"testing/quick"
"github.com/gonum/floats"
)
const tol = 1e-7
func TestAbs(t *testing.T) {
f := func(x float32) bool {
y := Abs(x)
return y == float32(math.Abs(float64(x)))
}
if err := quick.Check(f, nil); err != nil {
t.Error(err)
}
}
func TestCopySign(t *testing.T) {
f := func(x struct{ X, Y float32 }) bool {
y := Copysign(x.X, x.Y)
return y == float32(math.Copysign(float64(x.X), float64(x.Y)))
}
if err := quick.Check(f, nil); err != nil {
t.Error(err)
}
}
func TestHypot(t *testing.T) {
f := func(x struct{ X, Y float32 }) bool {
y := Hypot(x.X, x.Y)
if math.Hypot(float64(x.X), float64(x.Y)) > math.MaxFloat32 {
return true
}
return floats.EqualWithinRel(float64(y), math.Hypot(float64(x.X), float64(x.Y)), tol)
}
if err := quick.Check(f, nil); err != nil {
t.Error(err)
}
}
func TestInf(t *testing.T) {
if float64(Inf(1)) != math.Inf(1) || float64(Inf(-1)) != math.Inf(-1) {
t.Error("float32(inf) not infinite")
}
}
// TestIsInf checks IsInf on both infinities for every sign selector and
// then compares it with math.IsInf on randomly generated inputs.
func TestIsInf(t *testing.T) {
	posInf := float32(math.Inf(1))
	negInf := float32(math.Inf(-1))
	fixed := []struct {
		f    float32
		sign int
		want bool
	}{
		{posInf, 0, true},
		{negInf, 0, true},
		{posInf, 1, true},
		{negInf, -1, true},
		{posInf, -1, false},
		{negInf, 1, false},
	}
	for _, c := range fixed {
		if IsInf(c.f, c.sign) != c.want {
			// Report once, as a single combined check would.
			t.Error("unexpected isInf value")
			break
		}
	}

	agrees := func(x struct {
		F    float32
		Sign int
	}) bool {
		return IsInf(x.F, x.Sign) == math.IsInf(float64(x.F), x.Sign)
	}
	err := quick.Check(agrees, nil)
	if err != nil {
		t.Error(err)
	}
}
// TestIsNaN compares IsNaN with math.IsNaN on randomly generated inputs.
func TestIsNaN(t *testing.T) {
	agrees := func(x float32) bool {
		return IsNaN(x) == math.IsNaN(float64(x))
	}
	err := quick.Check(agrees, nil)
	if err != nil {
		t.Error(err)
	}
}
// TestNaN checks that the value returned by NaN is still a NaN after
// widening to float64.
func TestNaN(t *testing.T) {
	got := NaN()
	if math.IsNaN(float64(got)) {
		return
	}
	t.Errorf("float32(nan) is a number: %f", got)
}
// TestSqrt compares the exported Sqrt with the portable sqrt within a
// relative tolerance of tol on randomly generated inputs. Two NaN results
// are treated as agreeing.
func TestSqrt(t *testing.T) {
	agrees := func(x float32) bool {
		got, want := Sqrt(x), sqrt(x)
		if IsNaN(got) && IsNaN(want) {
			return true
		}
		return floats.EqualWithinRel(float64(got), float64(want), tol)
	}
	err := quick.Check(agrees, nil)
	if err != nil {
		t.Error(err)
	}
}
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// The original C code and the long comment below are
// from FreeBSD's /usr/src/lib/msun/src/e_sqrt.c and
// came with this notice. The go code is a simplified
// version of the original C.
//
// ====================================================
// Copyright (C) 1993 by Sun Microsystems, Inc. All rights reserved.
//
// Developed at SunPro, a Sun Microsystems, Inc. business.
// Permission to use, copy, modify, and distribute this
// software is freely granted, provided that this notice
// is preserved.
// ====================================================
//
// __ieee754_sqrt(x)
// Return correctly rounded sqrt.
// -----------------------------------------
// | Use the hardware sqrt if you have one |
// -----------------------------------------
// Method:
// Bit by bit method using integer arithmetic. (Slow, but portable)
// 1. Normalization
// Scale x to y in [1,4) with even powers of 2:
// find an integer k such that 1 <= (y=x*2**(2k)) < 4, then
// sqrt(x) = 2**k * sqrt(y)
// 2. Bit by bit computation
// Let q_i = sqrt(y) truncated to i bits after the binary point (q_0 = 1),
//
//     s_i = 2*q_i,  and  y_i = 2^(i+1) * (y - q_i^2).               (1)
//
// To compute q_(i+1) from q_i, one checks whether
//
//     (q_i + 2^-(i+1))^2 <= y.                                       (2)
//
// If (2) is false, then q_(i+1) = q_i; otherwise q_(i+1) = q_i + 2^-(i+1).
//
// With some algebraic manipulation, it is not difficult to see
// that (2) is equivalent to
//
//     s_i + 2^-(i+1) <= y_i                                          (3)
//
// The advantage of (3) is that s_i and y_i can be computed by
// the following recurrence formula:
// if (3) is false
//
//     s_(i+1) = s_i,  y_(i+1) = y_i;                                 (4)
//
// otherwise,
//
//     s_(i+1) = s_i + 2^-i,  y_(i+1) = y_i - s_i - 2^-(i+1)          (5)
//
//
// One may easily use induction to prove (4) and (5).
// Note. Since the left hand side of (3) contain only i+2 bits,
// it does not necessary to do a full (53-bit) comparison
// in (3).
// 3. Final rounding
// After generating the 53 bits result, we compute one more bit.
// Together with the remainder, we can decide whether the
// result is exact, bigger than 1/2ulp, or less than 1/2ulp
// (it will never equal to 1/2ulp).
// The rounding mode can be detected by checking whether
// huge + tiny is equal to huge, and whether huge - tiny is
// equal to huge for some floating point number "huge" and "tiny".
//
func sqrt(x float32) float32 {
	// special cases: ±0, NaN and +Inf are returned unchanged; any other
	// negative input has no real square root.
	switch {
	case x == 0 || IsNaN(x) || IsInf(x, 1):
		return x
	case x < 0:
		return NaN()
	}
	ix := math.Float32bits(x)

	// normalize x
	exp := int((ix >> shift) & mask)
	if exp == 0 { // subnormal x
		// Shift the significand left until bit `shift` is set, lowering
		// the exponent to match. The parentheses are required: & and <<
		// have the same precedence in Go, so ix&1<<shift parses as
		// (ix&1)<<shift, which tests only the lowest bit and either skips
		// the normalization or never terminates.
		for ix&(1<<shift) == 0 {
			ix <<= 1
			exp--
		}
		exp++
	}
	exp -= bias // unbias exponent
	ix &^= mask << shift
	ix |= 1 << shift
	if exp&1 == 1 { // odd exp, double x to make it even
		ix <<= 1
	}
	exp >>= 1 // exp = exp/2, exponent of square root

	// generate sqrt(x) bit by bit
	ix <<= 1
	var q, s uint32               // q = sqrt(x)
	r := uint32(1 << (shift + 1)) // r = moving bit from MSB to LSB
	for r != 0 {
		t := s + r
		if t <= ix {
			s = t + r
			ix -= t
			q += r
		}
		ix <<= 1
		r >>= 1
	}

	// final rounding
	if ix != 0 { // remainder, result not exact
		q += q & 1 // round according to extra bit
	}
	ix = q>>1 + uint32(exp-1+bias)<<shift // significand + biased exponent
	return math.Float32frombits(ix)
}

View File

@@ -0,0 +1,25 @@
// Copyright ©2015 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//+build !amd64 noasm appengine
package math32
import (
"math"
)
// Sqrt returns the square root of x.
//
// Special cases are:
//	Sqrt(+Inf) = +Inf
//	Sqrt(±0) = ±0
//	Sqrt(x < 0) = NaN
//	Sqrt(NaN) = NaN
func Sqrt(x float32) float32 {
	// FIXME(kortschak): Direct translation of the math package
	// asm code for 386 fails to build. No test hardware is available
	// for arm, so using conversion instead.
	wide := math.Sqrt(float64(x))
	return float32(wide)
}

View File

@@ -0,0 +1,20 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Copyright ©2015 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//+build !noasm,!appengine
package math32
// Sqrt returns the square root of x.
//
// Special cases are:
// Sqrt(+Inf) = +Inf
// Sqrt(±0) = ±0
// Sqrt(x < 0) = NaN
// Sqrt(NaN) = NaN
//
// This is a body-less declaration; the implementation is supplied outside
// of Go code for builds matching this file's build constraints.
func Sqrt(x float32) float32

View File

@@ -0,0 +1,20 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Copyright ©2015 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//+build !noasm,!appengine
// TODO(kortschak): use textflag.h after we drop Go 1.3 support
//#include "textflag.h"
// Don't insert stack check preamble.
#define NOSPLIT 4
// func Sqrt(x float32) float32
// Takes the square root of the argument with one SQRTSS instruction and
// stores it in the result slot. Declared NOSPLIT with a zero-size frame.
TEXT ·Sqrt(SB),NOSPLIT,$0
// X0 = sqrt(x), reading x from offset 0 of the argument frame.
SQRTSS x+0(FP), X0
// Store X0 to the return value at offset 8 of the argument frame.
MOVSS X0, ret+8(FP)
RET

610
blas/native/level1double.go Normal file
View File

@@ -0,0 +1,610 @@
// Copyright ©2015 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package native
import (
"math"
"github.com/gonum/blas"
"github.com/gonum/internal/asm/f64"
)
var _ blas.Float64Level1 = Implementation{}
// Dnrm2 computes the Euclidean norm of a vector,
//  sqrt(\sum_i x[i] * x[i]).
// This function returns 0 if incX is negative.
//
// Dnrm2 panics if incX is zero, if x is too short to hold n elements at
// stride incX, or if n is negative. A NaN element yields NaN, and an
// infinite running scale yields +Inf.
func (Implementation) Dnrm2(n int, x []float64, incX int) float64 {
	switch {
	case incX == 0:
		panic(zeroIncX)
	case incX < 0:
		return 0
	}
	if (n-1)*incX >= len(x) {
		panic(badX)
	}
	switch {
	case n < 0:
		panic(negativeN)
	case n == 0:
		return 0
	case n == 1:
		return math.Abs(x[0])
	}

	// The running sum of squares is held as scale*scale*ssq, where scale
	// is the largest magnitude seen so far; every squared term is a ratio
	// no greater than one.
	var scale float64
	ssq := 1.0
	if incX == 1 {
		for _, v := range x[:n] {
			if v == 0 {
				continue
			}
			a := math.Abs(v)
			if math.IsNaN(a) {
				return math.NaN()
			}
			if scale < a {
				ratio := scale / a
				ssq = 1 + ssq*ratio*ratio
				scale = a
			} else {
				ratio := a / scale
				ssq += ratio * ratio
			}
		}
	} else {
		for i := 0; i < n; i++ {
			v := x[i*incX]
			if v == 0 {
				continue
			}
			a := math.Abs(v)
			if math.IsNaN(a) {
				return math.NaN()
			}
			if scale < a {
				ratio := scale / a
				ssq = 1 + ssq*ratio*ratio
				scale = a
			} else {
				ratio := a / scale
				ssq += ratio * ratio
			}
		}
	}
	if math.IsInf(scale, 1) {
		return math.Inf(1)
	}
	return scale * math.Sqrt(ssq)
}
// Dasum computes the sum of the absolute values of the elements of x.
//  \sum_i |x[i]|
// Dasum returns 0 if incX is negative.
//
// Dasum panics if n is negative, if incX is zero, or if x is too short to
// hold n elements at stride incX.
func (Implementation) Dasum(n int, x []float64, incX int) float64 {
	if n < 0 {
		panic(negativeN)
	}
	switch {
	case incX == 0:
		panic(zeroIncX)
	case incX < 0:
		return 0
	}
	if (n-1)*incX >= len(x) {
		panic(badX)
	}

	var total float64
	if incX == 1 {
		for _, v := range x[:n] {
			total += math.Abs(v)
		}
		return total
	}
	for i, ix := 0, 0; i < n; i, ix = i+1, ix+incX {
		total += math.Abs(x[ix])
	}
	return total
}
// Idamax returns the index of an element of x with the largest absolute value.
// If there are multiple such indices the earliest is returned.
// Idamax returns -1 if n == 0 or if incX is negative.
//
// Idamax panics if incX is zero, if x is too short to hold n elements at
// stride incX, or if n is negative. The returned index counts elements,
// not positions in x.
func (Implementation) Idamax(n int, x []float64, incX int) int {
	switch {
	case incX == 0:
		panic(zeroIncX)
	case incX < 0:
		return -1
	}
	if (n-1)*incX >= len(x) {
		panic(badX)
	}
	switch {
	case n < 0:
		panic(negativeN)
	case n == 0:
		return -1 // Netlib returns invalid index when n == 0
	case n == 1:
		return 0
	}

	best := 0
	bestAbs := math.Abs(x[0])
	if incX == 1 {
		for i, v := range x[:n] {
			// The strict comparison keeps the earliest index on ties.
			if a := math.Abs(v); a > bestAbs {
				bestAbs, best = a, i
			}
		}
		return best
	}
	for i := 1; i < n; i++ {
		if a := math.Abs(x[i*incX]); a > bestAbs {
			bestAbs, best = a, i
		}
	}
	return best
}
// Dswap exchanges the elements of two vectors.
//  x[i], y[i] = y[i], x[i] for all i
//
// Dswap panics if either increment is zero, if n is negative, or if x or y
// is too short for n elements at its stride. A negative increment starts
// at the far end of the slice and walks towards index 0.
func (Implementation) Dswap(n int, x []float64, incX int, y []float64, incY int) {
	switch {
	case incX == 0:
		panic(zeroIncX)
	case incY == 0:
		panic(zeroIncY)
	case n < 0:
		panic(negativeN)
	case n == 0:
		return
	}
	if (incX > 0 && (n-1)*incX >= len(x)) || (incX < 0 && (1-n)*incX >= len(x)) {
		panic(badX)
	}
	if (incY > 0 && (n-1)*incY >= len(y)) || (incY < 0 && (1-n)*incY >= len(y)) {
		panic(badY)
	}

	if incX == 1 && incY == 1 {
		for i := range x[:n] {
			x[i], y[i] = y[i], x[i]
		}
		return
	}

	ix, iy := 0, 0
	if incX < 0 {
		ix = (1 - n) * incX
	}
	if incY < 0 {
		iy = (1 - n) * incY
	}
	for k := 0; k < n; k++ {
		x[ix], y[iy] = y[iy], x[ix]
		ix, iy = ix+incX, iy+incY
	}
}
// Dcopy copies the elements of x into the elements of y.
//  y[i] = x[i] for all i
//
// Dcopy panics if either increment is zero, if n is negative, or if x or y
// is too short for n elements at its stride. A negative increment starts
// at the far end of the slice and walks towards index 0.
func (Implementation) Dcopy(n int, x []float64, incX int, y []float64, incY int) {
	switch {
	case incX == 0:
		panic(zeroIncX)
	case incY == 0:
		panic(zeroIncY)
	case n < 0:
		panic(negativeN)
	case n == 0:
		return
	}
	if (incX > 0 && (n-1)*incX >= len(x)) || (incX < 0 && (1-n)*incX >= len(x)) {
		panic(badX)
	}
	if (incY > 0 && (n-1)*incY >= len(y)) || (incY < 0 && (1-n)*incY >= len(y)) {
		panic(badY)
	}

	if incX == 1 && incY == 1 {
		copy(y[:n], x[:n])
		return
	}

	ix, iy := 0, 0
	if incX < 0 {
		ix = (1 - n) * incX
	}
	if incY < 0 {
		iy = (1 - n) * incY
	}
	for k := 0; k < n; k++ {
		y[iy] = x[ix]
		ix, iy = ix+incX, iy+incY
	}
}
// Daxpy adds alpha times x to y
//  y[i] += alpha * x[i] for all i
//
// Daxpy panics if either increment is zero, if n is negative, or if x or y
// is too short for n elements at its stride. It returns without touching
// y when n or alpha is zero.
func (Implementation) Daxpy(n int, alpha float64, x []float64, incX int, y []float64, incY int) {
	switch {
	case incX == 0:
		panic(zeroIncX)
	case incY == 0:
		panic(zeroIncY)
	case n < 0:
		panic(negativeN)
	case n == 0:
		return
	}
	if (incX > 0 && (n-1)*incX >= len(x)) || (incX < 0 && (1-n)*incX >= len(x)) {
		panic(badX)
	}
	if (incY > 0 && (n-1)*incY >= len(y)) || (incY < 0 && (1-n)*incY >= len(y)) {
		panic(badY)
	}
	if alpha == 0 {
		return
	}

	if incX == 1 && incY == 1 {
		switch {
		case len(x) < n:
			panic(badLenX)
		case len(y) < n:
			panic(badLenY)
		}
		f64.AxpyUnitaryTo(y, alpha, x[:n], y)
		return
	}

	// Starting offsets: a negative increment begins at the far end.
	startX, startY := 0, 0
	if incX < 0 {
		startX = (1 - n) * incX
	}
	if incY < 0 {
		startY = (1 - n) * incY
	}
	if startX >= len(x) || startX+(n-1)*incX >= len(x) {
		panic(badLenX)
	}
	if startY >= len(y) || startY+(n-1)*incY >= len(y) {
		panic(badLenY)
	}
	f64.AxpyInc(alpha, x, y, uintptr(n), uintptr(incX), uintptr(incY), uintptr(startX), uintptr(startY))
}
// Drotg computes the plane rotation
//   _    _      _ _       _ _
//  |  c s |    | a |     | r |
//  | -s c |  * | b |   = | 0 |
//   ‾    ‾      ‾ ‾       ‾ ‾
// where
//  r = ±√(a^2 + b^2)
//  c = a/r, the cosine of the plane rotation
//  s = b/r, the sine of the plane rotation
//
// NOTE: There is a discrepancy between the reference implementation and the BLAS
// technical manual regarding the sign for r when a or b are zero.
// Drotg agrees with the definition in the manual and other
// common BLAS implementations.
func (Implementation) Drotg(a, b float64) (c, s, r, z float64) {
	if a == 0 && b == 0 {
		return 1, 0, a, 0
	}

	// r takes the sign of whichever input has the larger magnitude; b
	// wins when the magnitudes are equal.
	aDominates := math.Abs(a) > math.Abs(b)
	signFrom := b
	if aDominates {
		signFrom = a
	}
	r = math.Copysign(math.Hypot(a, b), signFrom)
	c, s = a/r, b/r

	switch {
	case aDominates:
		z = s
	case c != 0: // r == 0 case handled above
		z = 1 / c
	default:
		z = 1
	}
	return c, s, r, z
}
// Drotmg computes the modified Givens rotation. See
// http://www.netlib.org/lapack/explore-html/df/deb/drotmg_8f.html
// for more details.
//
// The rotation is returned in p (a flag plus the entries of H that the flag
// requires), together with the updated scale factors rd1 and rd2 and the
// updated first coordinate rx1.
func (Implementation) Drotmg(d1, d2, x1, y1 float64) (p blas.DrotmParams, rd1, rd2, rx1 float64) {
var p1, p2, q1, q2, u float64
// Rescaling thresholds: gamsq is gam squared and rgamsq is 1/gamsq
// to the digits given.
const (
gam = 4096.0
gamsq = 16777216.0
rgamsq = 5.9604645e-8
)
// A negative d1 returns the Rescaling flag with H and all three result
// values left at zero.
if d1 < 0 {
p.Flag = blas.Rescaling
return
}
p2 = d2 * y1
// Nothing to rotate: return the inputs unchanged with the Identity flag.
if p2 == 0 {
p.Flag = blas.Identity
rd1 = d1
rd2 = d2
rx1 = x1
return
}
p1 = d1 * x1
q2 = p2 * y1
q1 = p1 * x1
absQ1 := math.Abs(q1)
absQ2 := math.Abs(q2)
// Same zeroed result as for a negative d1.
if absQ1 < absQ2 && q2 < 0 {
p.Flag = blas.Rescaling
return
}
// With d1 == 0, p1 is zero as well, so H[0] is 0 and u is 1 here.
if d1 == 0 {
p.Flag = blas.Diagonal
p.H[0] = p1 / p2
p.H[3] = x1 / y1
u = 1 + p.H[0]*p.H[3]
rd1, rd2 = d2/u, d1/u
rx1 = y1 / u
return
}
// Now we know that d1 != 0, and d2 != 0. If d2 == 0, it would be caught
// when p2 == 0, and if d1 == 0, then it is caught above
if absQ1 > absQ2 {
p.H[1] = -y1 / x1
p.H[2] = p2 / p1
u = 1 - p.H[2]*p.H[1]
rd1 = d1
rd2 = d2
rx1 = x1
p.Flag = blas.OffDiagonal
// u must be greater than zero because |q1| > |q2|, so check from netlib
// is unnecessary
// This is left in for ease of comparison with complex routines
//if u > 0 {
rd1 /= u
rd2 /= u
rx1 *= u
//}
} else {
p.Flag = blas.Diagonal
p.H[0] = p1 / p2
p.H[3] = x1 / y1
u = 1 + p.H[0]*p.H[3]
rd1 = d2 / u
rd2 = d1 / u
rx1 = y1 * u
}
// Rescale rd1 until it lies strictly between rgamsq and gamsq. The first
// pass fills in the entries of H that the OffDiagonal/Diagonal flags leave
// implicit and switches the flag to Rescaling.
for rd1 <= rgamsq || rd1 >= gamsq {
if p.Flag == blas.OffDiagonal {
p.H[0] = 1
p.H[3] = 1
p.Flag = blas.Rescaling
} else if p.Flag == blas.Diagonal {
p.H[1] = -1
p.H[2] = 1
p.Flag = blas.Rescaling
}
if rd1 <= rgamsq {
rd1 *= gam * gam
rx1 /= gam
p.H[0] /= gam
p.H[2] /= gam
} else {
rd1 /= gam * gam
rx1 *= gam
p.H[0] *= gam
p.H[2] *= gam
}
}
// Same rescaling for the magnitude of rd2; this loop scales H[1] and H[3]
// and leaves rx1 alone.
for math.Abs(rd2) <= rgamsq || math.Abs(rd2) >= gamsq {
if p.Flag == blas.OffDiagonal {
p.H[0] = 1
p.H[3] = 1
p.Flag = blas.Rescaling
} else if p.Flag == blas.Diagonal {
p.H[1] = -1
p.H[2] = 1
p.Flag = blas.Rescaling
}
if math.Abs(rd2) <= rgamsq {
rd2 *= gam * gam
p.H[1] /= gam
p.H[3] /= gam
} else {
rd2 /= gam * gam
p.H[1] *= gam
p.H[3] *= gam
}
}
return
}
// Drot applies a plane transformation.
//  x[i] = c * x[i] + s * y[i]
//  y[i] = c * y[i] - s * x[i]
//
// Both right-hand sides use the values of x[i] and y[i] from before the
// update. Drot panics if either increment is zero, if n is negative, or if
// x or y is too short for n elements at its stride.
func (Implementation) Drot(n int, x []float64, incX int, y []float64, incY int, c float64, s float64) {
	switch {
	case incX == 0:
		panic(zeroIncX)
	case incY == 0:
		panic(zeroIncY)
	case n < 0:
		panic(negativeN)
	case n == 0:
		return
	}
	if (incX > 0 && (n-1)*incX >= len(x)) || (incX < 0 && (1-n)*incX >= len(x)) {
		panic(badX)
	}
	if (incY > 0 && (n-1)*incY >= len(y)) || (incY < 0 && (1-n)*incY >= len(y)) {
		panic(badY)
	}

	if incX == 1 && incY == 1 {
		for i := range x[:n] {
			xi, yi := x[i], y[i]
			x[i], y[i] = c*xi+s*yi, c*yi-s*xi
		}
		return
	}

	ix, iy := 0, 0
	if incX < 0 {
		ix = (1 - n) * incX
	}
	if incY < 0 {
		iy = (1 - n) * incY
	}
	for k := 0; k < n; k++ {
		xi, yi := x[ix], y[iy]
		x[ix], y[iy] = c*xi+s*yi, c*yi-s*xi
		ix, iy = ix+incX, iy+incY
	}
}
// Drotm applies the modified Givens rotation to the 2×n matrix.
//
// p.Flag selects which entries of p.H are read; the remaining entries of
// the 2×2 matrix are fixed at 1 or -1. With the Identity flag x and y are
// left untouched. Drotm panics if either increment is zero, if n is
// negative, or if x or y is too short for n elements at its stride.
func (Implementation) Drotm(n int, x []float64, incX int, y []float64, incY int, p blas.DrotmParams) {
	switch {
	case incX == 0:
		panic(zeroIncX)
	case incY == 0:
		panic(zeroIncY)
	case n < 0:
		panic(negativeN)
	case n == 0:
		return
	}
	if (incX > 0 && (n-1)*incX >= len(x)) || (incX < 0 && (1-n)*incX >= len(x)) {
		panic(badX)
	}
	if (incY > 0 && (n-1)*incY >= len(y)) || (incY < 0 && (1-n)*incY >= len(y)) {
		panic(badY)
	}

	// Any other flag value leaves all four entries at zero.
	var h11, h12, h21, h22 float64
	switch p.Flag {
	case blas.Identity:
		return
	case blas.Rescaling:
		h11, h21, h12, h22 = p.H[0], p.H[1], p.H[2], p.H[3]
	case blas.OffDiagonal:
		h11, h21, h12, h22 = 1, p.H[1], p.H[2], 1
	case blas.Diagonal:
		h11, h21, h12, h22 = p.H[0], -1, 1, p.H[3]
	}

	if incX == 1 && incY == 1 {
		for i := range x[:n] {
			xi, yi := x[i], y[i]
			x[i], y[i] = xi*h11+yi*h12, xi*h21+yi*h22
		}
		return
	}

	ix, iy := 0, 0
	if incX < 0 {
		ix = (1 - n) * incX
	}
	if incY < 0 {
		iy = (1 - n) * incY
	}
	for k := 0; k < n; k++ {
		xi, yi := x[ix], y[iy]
		x[ix], y[iy] = xi*h11+yi*h12, xi*h21+yi*h22
		ix, iy = ix+incX, iy+incY
	}
}
// Dscal scales x by alpha.
//  x[i] *= alpha
// Dscal has no effect if incX < 0.
//
// Dscal panics if incX is zero, if x is too short to hold n elements at
// stride incX, or if n is negative. When alpha is zero the selected
// elements are assigned zero directly rather than multiplied.
func (Implementation) Dscal(n int, alpha float64, x []float64, incX int) {
	switch {
	case incX == 0:
		panic(zeroIncX)
	case incX < 0:
		return
	}
	if (n-1)*incX >= len(x) {
		panic(badX)
	}
	switch {
	case n < 0:
		panic(negativeN)
	case n == 0:
		return
	}

	if alpha == 0 {
		if incX == 1 {
			head := x[:n]
			for i := range head {
				head[i] = 0
			}
			return
		}
		for i := 0; i < n; i++ {
			x[i*incX] = 0
		}
		return
	}

	if incX == 1 {
		f64.ScalUnitary(alpha, x[:n])
		return
	}
	for i := 0; i < n; i++ {
		x[i*incX] *= alpha
	}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,49 @@
// Copyright ©2015 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package native
import (
"github.com/gonum/internal/asm/f64"
)
// Ddot computes the dot product of the two vectors
//  \sum_i x[i]*y[i]
//
// Ddot returns 0 when n is zero. It panics if either increment is zero, if
// n is negative, or if x or y is too short for n elements at its stride.
func (Implementation) Ddot(n int, x []float64, incX int, y []float64, incY int) float64 {
	switch {
	case incX == 0:
		panic(zeroIncX)
	case incY == 0:
		panic(zeroIncY)
	case n < 0:
		panic(negativeN)
	case n == 0:
		return 0
	}

	if incX == 1 && incY == 1 {
		switch {
		case len(x) < n:
			panic(badLenX)
		case len(y) < n:
			panic(badLenY)
		}
		return f64.DotUnitary(x[:n], y)
	}

	// Starting offsets: a negative increment begins at the far end.
	startX, startY := 0, 0
	if incX < 0 {
		startX = (1 - n) * incX
	}
	if incY < 0 {
		startY = (1 - n) * incY
	}
	if startX >= len(x) || startX+(n-1)*incX >= len(x) {
		panic(badLenX)
	}
	if startY >= len(y) || startY+(n-1)*incY >= len(y) {
		panic(badLenY)
	}
	return f64.DotInc(x, y, uintptr(n), uintptr(incX), uintptr(incY), uintptr(startX), uintptr(startY))
}

View File

@@ -0,0 +1,61 @@
// Copyright ©2014 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package native
import (
"testing"
"github.com/gonum/blas/testblas"
)
var impl Implementation
// TestDasum passes the package-level impl to testblas.DasumTest.
func TestDasum(t *testing.T) {
testblas.DasumTest(t, impl)
}
// TestDaxpy passes the package-level impl to testblas.DaxpyTest.
func TestDaxpy(t *testing.T) {
testblas.DaxpyTest(t, impl)
}
// TestDdot passes the package-level impl to testblas.DdotTest.
func TestDdot(t *testing.T) {
testblas.DdotTest(t, impl)
}
// TestDnrm2 passes the package-level impl to testblas.Dnrm2Test.
func TestDnrm2(t *testing.T) {
testblas.Dnrm2Test(t, impl)
}
// TestIdamax passes the package-level impl to testblas.IdamaxTest.
func TestIdamax(t *testing.T) {
testblas.IdamaxTest(t, impl)
}
// TestDswap passes the package-level impl to testblas.DswapTest.
func TestDswap(t *testing.T) {
testblas.DswapTest(t, impl)
}
// TestDcopy passes the package-level impl to testblas.DcopyTest.
func TestDcopy(t *testing.T) {
testblas.DcopyTest(t, impl)
}
// TestDrotg passes the package-level impl to testblas.DrotgTest.
func TestDrotg(t *testing.T) {
testblas.DrotgTest(t, impl)
}
// TestDrotmg passes the package-level impl to testblas.DrotmgTest.
func TestDrotmg(t *testing.T) {
testblas.DrotmgTest(t, impl)
}
// TestDrot passes the package-level impl to testblas.DrotTest.
func TestDrot(t *testing.T) {
testblas.DrotTest(t, impl)
}
// TestDrotm passes the package-level impl to testblas.DrotmTest.
func TestDrotm(t *testing.T) {
testblas.DrotmTest(t, impl)
}
// TestDscal passes the package-level impl to testblas.DscalTest.
func TestDscal(t *testing.T) {
testblas.DscalTest(t, impl)
}

634
blas/native/level1single.go Normal file
View File

@@ -0,0 +1,634 @@
// Code generated by "go generate github.com/gonum/blas/native"; DO NOT EDIT.
// Copyright ©2015 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package native
import (
math "github.com/gonum/blas/native/internal/math32"
"github.com/gonum/blas"
"github.com/gonum/internal/asm/f32"
)
var _ blas.Float32Level1 = Implementation{}
// Snrm2 computes the Euclidean norm of a vector,
// sqrt(\sum_i x[i] * x[i]).
// This function returns 0 if incX is negative.
//
// Snrm2 panics if incX is zero, if x is too short to hold n elements at
// stride incX, or if n is negative.
//
// Float32 implementations are autogenerated and not directly tested.
func (Implementation) Snrm2(n int, x []float32, incX int) float32 {
if incX < 1 {
if incX == 0 {
panic(zeroIncX)
}
return 0
}
if incX > 0 && (n-1)*incX >= len(x) {
panic(badX)
}
if n < 2 {
if n == 1 {
return math.Abs(x[0])
}
if n == 0 {
return 0
}
if n < 1 {
panic(negativeN)
}
}
// The running sum of squares is held as scale*scale*sumSquares, where
// scale is the largest magnitude seen so far.
var (
scale float32 = 0
sumSquares float32 = 1
)
if incX == 1 {
x = x[:n]
for _, v := range x {
if v == 0 {
continue
}
absxi := math.Abs(v)
if math.IsNaN(absxi) {
return math.NaN()
}
if scale < absxi {
// New largest magnitude: re-express the sum relative to it.
sumSquares = 1 + sumSquares*(scale/absxi)*(scale/absxi)
scale = absxi
} else {
sumSquares = sumSquares + (absxi/scale)*(absxi/scale)
}
}
if math.IsInf(scale, 1) {
return math.Inf(1)
}
return scale * math.Sqrt(sumSquares)
}
// Strided path: same accumulation over every incX-th element.
for ix := 0; ix < n*incX; ix += incX {
val := x[ix]
if val == 0 {
continue
}
absxi := math.Abs(val)
if math.IsNaN(absxi) {
return math.NaN()
}
if scale < absxi {
sumSquares = 1 + sumSquares*(scale/absxi)*(scale/absxi)
scale = absxi
} else {
sumSquares = sumSquares + (absxi/scale)*(absxi/scale)
}
}
if math.IsInf(scale, 1) {
return math.Inf(1)
}
return scale * math.Sqrt(sumSquares)
}
// Sasum computes the sum of the absolute values of the elements of x.
// \sum_i |x[i]|
// Sasum returns 0 if incX is negative.
//
// Sasum panics if n is negative, if incX is zero, or if x is too short to
// hold n elements at stride incX.
//
// Float32 implementations are autogenerated and not directly tested.
func (Implementation) Sasum(n int, x []float32, incX int) float32 {
var sum float32
if n < 0 {
panic(negativeN)
}
if incX < 1 {
if incX == 0 {
panic(zeroIncX)
}
return 0
}
if incX > 0 && (n-1)*incX >= len(x) {
panic(badX)
}
// Unit stride: range over the first n elements.
if incX == 1 {
x = x[:n]
for _, v := range x {
sum += math.Abs(v)
}
return sum
}
for i := 0; i < n; i++ {
sum += math.Abs(x[i*incX])
}
return sum
}
// Isamax returns the index of an element of x with the largest absolute value.
// If there are multiple such indices the earliest is returned.
// Isamax returns -1 if n == 0. It also returns -1 if incX is negative.
//
// The returned index counts elements, not positions in x.
//
// Float32 implementations are autogenerated and not directly tested.
func (Implementation) Isamax(n int, x []float32, incX int) int {
if incX < 1 {
if incX == 0 {
panic(zeroIncX)
}
return -1
}
if incX > 0 && (n-1)*incX >= len(x) {
panic(badX)
}
if n < 2 {
if n == 1 {
return 0
}
if n == 0 {
return -1 // Netlib returns invalid index when n == 0
}
if n < 1 {
panic(negativeN)
}
}
idx := 0
max := math.Abs(x[0])
if incX == 1 {
for i, v := range x[:n] {
absV := math.Abs(v)
// The strict comparison keeps the earliest index on ties.
if absV > max {
max = absV
idx = i
}
}
return idx
}
// Strided path: element 0 is already the current maximum, so start at 1.
ix := incX
for i := 1; i < n; i++ {
v := x[ix]
absV := math.Abs(v)
if absV > max {
max = absV
idx = i
}
ix += incX
}
return idx
}
// Sswap exchanges the elements of two vectors.
// x[i], y[i] = y[i], x[i] for all i
//
// Sswap panics if either increment is zero, if n is negative, or if x or y
// is too short for n elements at its stride.
//
// Float32 implementations are autogenerated and not directly tested.
func (Implementation) Sswap(n int, x []float32, incX int, y []float32, incY int) {
if incX == 0 {
panic(zeroIncX)
}
if incY == 0 {
panic(zeroIncY)
}
if n < 1 {
if n == 0 {
return
}
panic(negativeN)
}
if (incX > 0 && (n-1)*incX >= len(x)) || (incX < 0 && (1-n)*incX >= len(x)) {
panic(badX)
}
if (incY > 0 && (n-1)*incY >= len(y)) || (incY < 0 && (1-n)*incY >= len(y)) {
panic(badY)
}
if incX == 1 && incY == 1 {
x = x[:n]
for i, v := range x {
x[i], y[i] = y[i], v
}
return
}
// A negative increment starts at the far end and walks towards index 0.
var ix, iy int
if incX < 0 {
ix = (-n + 1) * incX
}
if incY < 0 {
iy = (-n + 1) * incY
}
for i := 0; i < n; i++ {
x[ix], y[iy] = y[iy], x[ix]
ix += incX
iy += incY
}
}
// Scopy copies the elements of x into the elements of y.
// y[i] = x[i] for all i
//
// Scopy panics if either increment is zero, if n is negative, or if x or y
// is too short for n elements at its stride.
//
// Float32 implementations are autogenerated and not directly tested.
func (Implementation) Scopy(n int, x []float32, incX int, y []float32, incY int) {
if incX == 0 {
panic(zeroIncX)
}
if incY == 0 {
panic(zeroIncY)
}
if n < 1 {
if n == 0 {
return
}
panic(negativeN)
}
if (incX > 0 && (n-1)*incX >= len(x)) || (incX < 0 && (1-n)*incX >= len(x)) {
panic(badX)
}
if (incY > 0 && (n-1)*incY >= len(y)) || (incY < 0 && (1-n)*incY >= len(y)) {
panic(badY)
}
// Unit strides use the built-in copy.
if incX == 1 && incY == 1 {
copy(y[:n], x[:n])
return
}
// A negative increment starts at the far end and walks towards index 0.
var ix, iy int
if incX < 0 {
ix = (-n + 1) * incX
}
if incY < 0 {
iy = (-n + 1) * incY
}
for i := 0; i < n; i++ {
y[iy] = x[ix]
ix += incX
iy += incY
}
}
// Saxpy adds alpha times x to y
// y[i] += alpha * x[i] for all i
//
// Saxpy panics if either increment is zero, if n is negative, or if x or y
// is too short for n elements at its stride. It returns without touching
// y when n or alpha is zero.
//
// Float32 implementations are autogenerated and not directly tested.
func (Implementation) Saxpy(n int, alpha float32, x []float32, incX int, y []float32, incY int) {
if incX == 0 {
panic(zeroIncX)
}
if incY == 0 {
panic(zeroIncY)
}
if n < 1 {
if n == 0 {
return
}
panic(negativeN)
}
if (incX > 0 && (n-1)*incX >= len(x)) || (incX < 0 && (1-n)*incX >= len(x)) {
panic(badX)
}
if (incY > 0 && (n-1)*incY >= len(y)) || (incY < 0 && (1-n)*incY >= len(y)) {
panic(badY)
}
if alpha == 0 {
return
}
if incX == 1 && incY == 1 {
if len(x) < n {
panic(badLenX)
}
if len(y) < n {
panic(badLenY)
}
// y is passed both as the destination and as the addend.
f32.AxpyUnitaryTo(y, alpha, x[:n], y)
return
}
// A negative increment starts at the far end and walks towards index 0.
var ix, iy int
if incX < 0 {
ix = (-n + 1) * incX
}
if incY < 0 {
iy = (-n + 1) * incY
}
if ix >= len(x) || ix+(n-1)*incX >= len(x) {
panic(badLenX)
}
if iy >= len(y) || iy+(n-1)*incY >= len(y) {
panic(badLenY)
}
f32.AxpyInc(alpha, x, y, uintptr(n), uintptr(incX), uintptr(incY), uintptr(ix), uintptr(iy))
}
// Srotg computes the plane rotation
// _ _ _ _ _ _
// | c s | | a | | r |
// | -s c | * | b | = | 0 |
// ‾ ‾ ‾ ‾ ‾ ‾
// where
// r = ±√(a^2 + b^2)
// c = a/r, the cosine of the plane rotation
// s = b/r, the sine of the plane rotation
//
// NOTE: There is a discrepancy between the reference implementation and the BLAS
// technical manual regarding the sign for r when a or b are zero.
// Srotg agrees with the definition in the manual and other
// common BLAS implementations.
//
// Float32 implementations are autogenerated and not directly tested.
func (Implementation) Srotg(a, b float32) (c, s, r, z float32) {
if b == 0 && a == 0 {
return 1, 0, a, 0
}
absA := math.Abs(a)
absB := math.Abs(b)
aGTb := absA > absB
r = math.Hypot(a, b)
// r takes the sign of whichever input has the larger magnitude; b wins
// when the magnitudes are equal.
if aGTb {
r = math.Copysign(r, a)
} else {
r = math.Copysign(r, b)
}
c = a / r
s = b / r
if aGTb {
z = s
} else if c != 0 { // r == 0 case handled above
z = 1 / c
} else {
z = 1
}
return
}
// Srotmg computes the modified Givens rotation. See
// http://www.netlib.org/lapack/explore-html/df/deb/drotmg_8f.html
// for more details.
//
// The rotation is returned in p (a flag plus the entries of H that the flag
// requires), together with the updated scale factors rd1 and rd2 and the
// updated first coordinate rx1.
//
// Float32 implementations are autogenerated and not directly tested.
func (Implementation) Srotmg(d1, d2, x1, y1 float32) (p blas.SrotmParams, rd1, rd2, rx1 float32) {
var p1, p2, q1, q2, u float32
// Rescaling thresholds: gamsq is gam squared and rgamsq is 1/gamsq
// to the digits given.
const (
gam = 4096.0
gamsq = 16777216.0
rgamsq = 5.9604645e-8
)
// A negative d1 returns the Rescaling flag with H and all three result
// values left at zero.
if d1 < 0 {
p.Flag = blas.Rescaling
return
}
p2 = d2 * y1
// Nothing to rotate: return the inputs unchanged with the Identity flag.
if p2 == 0 {
p.Flag = blas.Identity
rd1 = d1
rd2 = d2
rx1 = x1
return
}
p1 = d1 * x1
q2 = p2 * y1
q1 = p1 * x1
absQ1 := math.Abs(q1)
absQ2 := math.Abs(q2)
// Same zeroed result as for a negative d1.
if absQ1 < absQ2 && q2 < 0 {
p.Flag = blas.Rescaling
return
}
// With d1 == 0, p1 is zero as well, so H[0] is 0 and u is 1 here.
if d1 == 0 {
p.Flag = blas.Diagonal
p.H[0] = p1 / p2
p.H[3] = x1 / y1
u = 1 + p.H[0]*p.H[3]
rd1, rd2 = d2/u, d1/u
rx1 = y1 / u
return
}
// Now we know that d1 != 0, and d2 != 0. If d2 == 0, it would be caught
// when p2 == 0, and if d1 == 0, then it is caught above
if absQ1 > absQ2 {
p.H[1] = -y1 / x1
p.H[2] = p2 / p1
u = 1 - p.H[2]*p.H[1]
rd1 = d1
rd2 = d2
rx1 = x1
p.Flag = blas.OffDiagonal
// u must be greater than zero because |q1| > |q2|, so check from netlib
// is unnecessary
// This is left in for ease of comparison with complex routines
//if u > 0 {
rd1 /= u
rd2 /= u
rx1 *= u
//}
} else {
p.Flag = blas.Diagonal
p.H[0] = p1 / p2
p.H[3] = x1 / y1
u = 1 + p.H[0]*p.H[3]
rd1 = d2 / u
rd2 = d1 / u
rx1 = y1 * u
}
// Rescale rd1 until it lies strictly between rgamsq and gamsq. The first
// pass fills in the entries of H that the OffDiagonal/Diagonal flags leave
// implicit and switches the flag to Rescaling.
for rd1 <= rgamsq || rd1 >= gamsq {
if p.Flag == blas.OffDiagonal {
p.H[0] = 1
p.H[3] = 1
p.Flag = blas.Rescaling
} else if p.Flag == blas.Diagonal {
p.H[1] = -1
p.H[2] = 1
p.Flag = blas.Rescaling
}
if rd1 <= rgamsq {
rd1 *= gam * gam
rx1 /= gam
p.H[0] /= gam
p.H[2] /= gam
} else {
rd1 /= gam * gam
rx1 *= gam
p.H[0] *= gam
p.H[2] *= gam
}
}
// Same rescaling for the magnitude of rd2; this loop scales H[1] and H[3]
// and leaves rx1 alone.
for math.Abs(rd2) <= rgamsq || math.Abs(rd2) >= gamsq {
if p.Flag == blas.OffDiagonal {
p.H[0] = 1
p.H[3] = 1
p.Flag = blas.Rescaling
} else if p.Flag == blas.Diagonal {
p.H[1] = -1
p.H[2] = 1
p.Flag = blas.Rescaling
}
if math.Abs(rd2) <= rgamsq {
rd2 *= gam * gam
p.H[1] /= gam
p.H[3] /= gam
} else {
rd2 /= gam * gam
p.H[1] *= gam
p.H[3] *= gam
}
}
return
}
// Srot applies a plane transformation.
// x[i] = c * x[i] + s * y[i]
// y[i] = c * y[i] - s * x[i]
//
// Both right-hand sides use the values of x[i] and y[i] from before the
// update.
//
// Float32 implementations are autogenerated and not directly tested.
func (Implementation) Srot(n int, x []float32, incX int, y []float32, incY int, c float32, s float32) {
if incX == 0 {
panic(zeroIncX)
}
if incY == 0 {
panic(zeroIncY)
}
if n < 1 {
if n == 0 {
return
}
panic(negativeN)
}
if (incX > 0 && (n-1)*incX >= len(x)) || (incX < 0 && (1-n)*incX >= len(x)) {
panic(badX)
}
if (incY > 0 && (n-1)*incY >= len(y)) || (incY < 0 && (1-n)*incY >= len(y)) {
panic(badY)
}
if incX == 1 && incY == 1 {
x = x[:n]
for i, vx := range x {
vy := y[i]
x[i], y[i] = c*vx+s*vy, c*vy-s*vx
}
return
}
// A negative increment starts at the far end and walks towards index 0.
var ix, iy int
if incX < 0 {
ix = (-n + 1) * incX
}
if incY < 0 {
iy = (-n + 1) * incY
}
for i := 0; i < n; i++ {
vx := x[ix]
vy := y[iy]
x[ix], y[iy] = c*vx+s*vy, c*vy-s*vx
ix += incX
iy += incY
}
}
// Srotm applies the modified Givens rotation to the 2×n matrix.
//
// p.Flag selects which entries of p.H are read; the remaining entries of
// the 2×2 matrix are fixed at 1 or -1. With the Identity flag x and y are
// left untouched.
//
// Float32 implementations are autogenerated and not directly tested.
func (Implementation) Srotm(n int, x []float32, incX int, y []float32, incY int, p blas.SrotmParams) {
if incX == 0 {
panic(zeroIncX)
}
if incY == 0 {
panic(zeroIncY)
}
if n <= 0 {
if n == 0 {
return
}
panic(negativeN)
}
if (incX > 0 && (n-1)*incX >= len(x)) || (incX < 0 && (1-n)*incX >= len(x)) {
panic(badX)
}
if (incY > 0 && (n-1)*incY >= len(y)) || (incY < 0 && (1-n)*incY >= len(y)) {
panic(badY)
}
var h11, h12, h21, h22 float32
var ix, iy int
// Any flag value without a case below leaves all four entries at zero.
switch p.Flag {
case blas.Identity:
return
case blas.Rescaling:
h11 = p.H[0]
h12 = p.H[2]
h21 = p.H[1]
h22 = p.H[3]
case blas.OffDiagonal:
h11 = 1
h12 = p.H[2]
h21 = p.H[1]
h22 = 1
case blas.Diagonal:
h11 = p.H[0]
h12 = 1
h21 = -1
h22 = p.H[3]
}
// A negative increment starts at the far end and walks towards index 0.
if incX < 0 {
ix = (-n + 1) * incX
}
if incY < 0 {
iy = (-n + 1) * incY
}
if incX == 1 && incY == 1 {
x = x[:n]
for i, vx := range x {
vy := y[i]
x[i], y[i] = vx*h11+vy*h12, vx*h21+vy*h22
}
return
}
for i := 0; i < n; i++ {
vx := x[ix]
vy := y[iy]
x[ix], y[iy] = vx*h11+vy*h12, vx*h21+vy*h22
ix += incX
iy += incY
}
return
}
// Sscal scales x by alpha.
// x[i] *= alpha
// Sscal has no effect if incX < 0.
//
// Sscal panics if incX is zero, if x is too short to hold n elements at
// stride incX, or if n is negative.
//
// Float32 implementations are autogenerated and not directly tested.
func (Implementation) Sscal(n int, alpha float32, x []float32, incX int) {
if incX < 1 {
if incX == 0 {
panic(zeroIncX)
}
return
}
if (n-1)*incX >= len(x) {
panic(badX)
}
if n < 1 {
if n == 0 {
return
}
panic(negativeN)
}
// A zero alpha assigns zero directly instead of multiplying.
if alpha == 0 {
if incX == 1 {
x = x[:n]
for i := range x {
x[i] = 0
}
return
}
for ix := 0; ix < n*incX; ix += incX {
x[ix] = 0
}
return
}
if incX == 1 {
f32.ScalUnitary(alpha, x[:n])
return
}
for ix := 0; ix < n*incX; ix += incX {
x[ix] *= alpha
}
}

View File

@@ -0,0 +1,53 @@
// Code generated by "go generate github.com/gonum/blas/native"; DO NOT EDIT.
// Copyright ©2015 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package native
import (
"github.com/gonum/internal/asm/f32"
)
// Dsdot computes the dot product of the two vectors
// \sum_i x[i]*y[i]
//
// The inputs are float32 and the result is returned as a float64. Dsdot
// returns 0 when n is zero and panics if either increment is zero, if n is
// negative, or if x or y is too short for n elements at its stride.
//
// Float32 implementations are autogenerated and not directly tested.
func (Implementation) Dsdot(n int, x []float32, incX int, y []float32, incY int) float64 {
if incX == 0 {
panic(zeroIncX)
}
if incY == 0 {
panic(zeroIncY)
}
if n <= 0 {
if n == 0 {
return 0
}
panic(negativeN)
}
if incX == 1 && incY == 1 {
if len(x) < n {
panic(badLenX)
}
if len(y) < n {
panic(badLenY)
}
return f32.DdotUnitary(x[:n], y)
}
// A negative increment starts at the far end and walks towards index 0.
var ix, iy int
if incX < 0 {
ix = (-n + 1) * incX
}
if incY < 0 {
iy = (-n + 1) * incY
}
if ix >= len(x) || ix+(n-1)*incX >= len(x) {
panic(badLenX)
}
if iy >= len(y) || iy+(n-1)*incY >= len(y) {
panic(badLenY)
}
return f32.DdotInc(x, y, uintptr(n), uintptr(incX), uintptr(incY), uintptr(ix), uintptr(iy))
}

View File

@@ -0,0 +1,53 @@
// Code generated by "go generate github.com/gonum/blas/native"; DO NOT EDIT.
// Copyright ©2015 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package native
import (
"github.com/gonum/internal/asm/f32"
)
// Sdot computes the dot product of the two vectors
// \sum_i x[i]*y[i]
//
// Sdot returns 0 when n is zero and panics if either increment is zero, if
// n is negative, or if x or y is too short for n elements at its stride.
//
// Float32 implementations are autogenerated and not directly tested.
func (Implementation) Sdot(n int, x []float32, incX int, y []float32, incY int) float32 {
if incX == 0 {
panic(zeroIncX)
}
if incY == 0 {
panic(zeroIncY)
}
if n <= 0 {
if n == 0 {
return 0
}
panic(negativeN)
}
if incX == 1 && incY == 1 {
if len(x) < n {
panic(badLenX)
}
if len(y) < n {
panic(badLenY)
}
return f32.DotUnitary(x[:n], y)
}
// A negative increment starts at the far end and walks towards index 0.
var ix, iy int
if incX < 0 {
ix = (-n + 1) * incX
}
if incY < 0 {
iy = (-n + 1) * incY
}
if ix >= len(x) || ix+(n-1)*incX >= len(x) {
panic(badLenX)
}
if iy >= len(y) || iy+(n-1)*incY >= len(y) {
panic(badLenY)
}
return f32.DotInc(x, y, uintptr(n), uintptr(incX), uintptr(incY), uintptr(ix), uintptr(iy))
}

View File

@@ -0,0 +1,53 @@
// Code generated by "go generate github.com/gonum/blas/native"; DO NOT EDIT.
// Copyright ©2015 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package native
import (
"github.com/gonum/internal/asm/f32"
)
// Sdsdot computes the dot product of the two vectors plus a constant
//  alpha + \sum_i x[i]*y[i]
//
// When n is 0 the sum is empty and Sdsdot returns alpha.
//
// Sdsdot panics if incX or incY is zero, if n is negative, or if x or y is
// too short for n elements at the given increment.
//
// Float32 implementations are autogenerated and not directly tested.
func (Implementation) Sdsdot(n int, alpha float32, x []float32, incX int, y []float32, incY int) float32 {
	if incX == 0 {
		panic(zeroIncX)
	}
	if incY == 0 {
		panic(zeroIncY)
	}
	if n <= 0 {
		if n == 0 {
			// The empty sum contributes nothing, so only the constant
			// term remains.
			return alpha
		}
		panic(negativeN)
	}
	if incX == 1 && incY == 1 {
		if len(x) < n {
			panic(badLenX)
		}
		if len(y) < n {
			panic(badLenY)
		}
		return alpha + float32(f32.DdotUnitary(x[:n], y))
	}
	// With a negative increment the traversal starts at the far end of the
	// vector and walks towards index 0.
	var ix, iy int
	if incX < 0 {
		ix = (-n + 1) * incX
	}
	if incY < 0 {
		iy = (-n + 1) * incY
	}
	if ix >= len(x) || ix+(n-1)*incX >= len(x) {
		panic(badLenX)
	}
	if iy >= len(y) || iy+(n-1)*incY >= len(y) {
		panic(badLenY)
	}
	return alpha + float32(f32.DdotInc(x, y, uintptr(n), uintptr(incX), uintptr(incY), uintptr(ix), uintptr(iy)))
}

2236
blas/native/level2double.go Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,79 @@
// Copyright ©2014 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package native
import (
"testing"
"github.com/gonum/blas/testblas"
)
// TestDgemv passes the package-level impl to testblas.DgemvTest.
func TestDgemv(t *testing.T) {
	testblas.DgemvTest(t, impl)
}
// TestDger passes the package-level impl to testblas.DgerTest.
func TestDger(t *testing.T) {
	testblas.DgerTest(t, impl)
}
// TestDtxmv passes the package-level impl to testblas.DtxmvTest.
func TestDtxmv(t *testing.T) {
	testblas.DtxmvTest(t, impl)
}
// TestDgbmv passes the package-level impl to testblas.DgbmvTest.
func TestDgbmv(t *testing.T) {
	testblas.DgbmvTest(t, impl)
}
// TestDtbsv passes the package-level impl to testblas.DtbsvTest.
func TestDtbsv(t *testing.T) {
	testblas.DtbsvTest(t, impl)
}
// TestDsbmv passes the package-level impl to testblas.DsbmvTest.
func TestDsbmv(t *testing.T) {
	testblas.DsbmvTest(t, impl)
}
// TestDtbmv passes the package-level impl to testblas.DtbmvTest.
func TestDtbmv(t *testing.T) {
	testblas.DtbmvTest(t, impl)
}
// TestDtrsv passes the package-level impl to testblas.DtrsvTest.
func TestDtrsv(t *testing.T) {
	testblas.DtrsvTest(t, impl)
}
// TestDtrmv passes the package-level impl to testblas.DtrmvTest.
func TestDtrmv(t *testing.T) {
	testblas.DtrmvTest(t, impl)
}
// TestDsymv passes the package-level impl to testblas.DsymvTest.
func TestDsymv(t *testing.T) {
	testblas.DsymvTest(t, impl)
}
// TestDsyr passes the package-level impl to testblas.DsyrTest.
func TestDsyr(t *testing.T) {
	testblas.DsyrTest(t, impl)
}
// TestDsyr2 passes the package-level impl to testblas.Dsyr2Test.
func TestDsyr2(t *testing.T) {
	testblas.Dsyr2Test(t, impl)
}
// TestDspr2 passes the package-level impl to testblas.Dspr2Test.
func TestDspr2(t *testing.T) {
	testblas.Dspr2Test(t, impl)
}
// TestDspr passes the package-level impl to testblas.DsprTest.
func TestDspr(t *testing.T) {
	testblas.DsprTest(t, impl)
}
// TestDspmv passes the package-level impl to testblas.DspmvTest.
func TestDspmv(t *testing.T) {
	testblas.DspmvTest(t, impl)
}
// TestDtpsv passes the package-level impl to testblas.DtpsvTest.
func TestDtpsv(t *testing.T) {
	testblas.DtpsvTest(t, impl)
}
// TestDtpmv passes the package-level impl to testblas.DtpmvTest.
func TestDtpmv(t *testing.T) {
	testblas.DtpmvTest(t, impl)
}

2270
blas/native/level2single.go Normal file

File diff suppressed because it is too large Load Diff

831
blas/native/level3double.go Normal file
View File

@@ -0,0 +1,831 @@
// Copyright ©2014 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package native
import (
"github.com/gonum/blas"
"github.com/gonum/internal/asm/f64"
)
// Compile-time assertion that Implementation satisfies blas.Float64Level3.
var _ blas.Float64Level3 = Implementation{}
// Dtrsm solves
//  A * X = alpha * B,   if tA == blas.NoTrans side == blas.Left,
//  A^T * X = alpha * B, if tA == blas.Trans or blas.ConjTrans, and side == blas.Left,
//  X * A = alpha * B,   if tA == blas.NoTrans side == blas.Right,
//  X * A^T = alpha * B, if tA == blas.Trans or blas.ConjTrans, and side == blas.Right,
// where A is an n×n or m×m triangular matrix, X is an m×n matrix, and alpha is a
// scalar. A is m×m when s == blas.Left and n×n when s == blas.Right.
//
// At entry to the function, X contains the values of B, and the result is
// stored in place into X.
//
// When d == blas.Unit the diagonal of A is not read. If alpha is zero, X is
// set to zero without reading A.
//
// Dtrsm panics if any of s, ul, tA or d is not a valid value, if m or n is
// negative, or if a or b is too short for the given dimensions and strides.
//
// No check is made that A is invertible.
func (Implementation) Dtrsm(s blas.Side, ul blas.Uplo, tA blas.Transpose, d blas.Diag, m, n int, alpha float64, a []float64, lda int, b []float64, ldb int) {
	if s != blas.Left && s != blas.Right {
		panic(badSide)
	}
	if ul != blas.Lower && ul != blas.Upper {
		panic(badUplo)
	}
	if tA != blas.NoTrans && tA != blas.Trans && tA != blas.ConjTrans {
		panic(badTranspose)
	}
	if d != blas.NonUnit && d != blas.Unit {
		panic(badDiag)
	}
	if m < 0 {
		panic(mLT0)
	}
	if n < 0 {
		panic(nLT0)
	}
	if ldb < n {
		panic(badLdB)
	}
	// k is the order of A: m when A multiplies from the left, n otherwise.
	var k int
	if s == blas.Left {
		k = m
	} else {
		k = n
	}
	if lda*(k-1)+k > len(a) || lda < max(1, k) {
		panic(badLdA)
	}
	if ldb*(m-1)+n > len(b) || ldb < max(1, n) {
		panic(badLdB)
	}

	if m == 0 || n == 0 {
		return
	}

	if alpha == 0 {
		for i := 0; i < m; i++ {
			btmp := b[i*ldb : i*ldb+n]
			for j := range btmp {
				btmp[j] = 0
			}
		}
		return
	}
	nonUnit := d == blas.NonUnit
	if s == blas.Left {
		if tA == blas.NoTrans {
			if ul == blas.Upper {
				// Back substitution: the last row of X is solved first.
				for i := m - 1; i >= 0; i-- {
					btmp := b[i*ldb : i*ldb+n]
					if alpha != 1 {
						for j := range btmp {
							btmp[j] *= alpha
						}
					}
					for ka, va := range a[i*lda+i+1 : i*lda+m] {
						k := ka + i + 1
						if va != 0 {
							f64.AxpyUnitaryTo(btmp, -va, b[k*ldb:k*ldb+n], btmp)
						}
					}
					if nonUnit {
						tmp := 1 / a[i*lda+i]
						for j := 0; j < n; j++ {
							btmp[j] *= tmp
						}
					}
				}
				return
			}
			// Forward substitution: the first row of X is solved first.
			for i := 0; i < m; i++ {
				btmp := b[i*ldb : i*ldb+n]
				if alpha != 1 {
					for j := 0; j < n; j++ {
						btmp[j] *= alpha
					}
				}
				for k, va := range a[i*lda : i*lda+i] {
					if va != 0 {
						f64.AxpyUnitaryTo(btmp, -va, b[k*ldb:k*ldb+n], btmp)
					}
				}
				if nonUnit {
					tmp := 1 / a[i*lda+i]
					for j := 0; j < n; j++ {
						btmp[j] *= tmp
					}
				}
			}
			return
		}
		// Cases where a is transposed.
		if ul == blas.Upper {
			for k := 0; k < m; k++ {
				btmpk := b[k*ldb : k*ldb+n]
				if nonUnit {
					tmp := 1 / a[k*lda+k]
					for j := 0; j < n; j++ {
						btmpk[j] *= tmp
					}
				}
				for ia, va := range a[k*lda+k+1 : k*lda+m] {
					i := ia + k + 1
					if va != 0 {
						btmp := b[i*ldb : i*ldb+n]
						f64.AxpyUnitaryTo(btmp, -va, btmpk, btmp)
					}
				}
				// alpha is applied last so that the eliminations above
				// use the unscaled solution row.
				if alpha != 1 {
					for j := 0; j < n; j++ {
						btmpk[j] *= alpha
					}
				}
			}
			return
		}
		for k := m - 1; k >= 0; k-- {
			btmpk := b[k*ldb : k*ldb+n]
			if nonUnit {
				tmp := 1 / a[k*lda+k]
				for j := 0; j < n; j++ {
					btmpk[j] *= tmp
				}
			}
			for i, va := range a[k*lda : k*lda+k] {
				if va != 0 {
					btmp := b[i*ldb : i*ldb+n]
					f64.AxpyUnitaryTo(btmp, -va, btmpk, btmp)
				}
			}
			if alpha != 1 {
				for j := 0; j < n; j++ {
					btmpk[j] *= alpha
				}
			}
		}
		return
	}
	// Cases where a is to the right of X. Each row of X is solved
	// independently against A.
	if tA == blas.NoTrans {
		if ul == blas.Upper {
			for i := 0; i < m; i++ {
				btmp := b[i*ldb : i*ldb+n]
				if alpha != 1 {
					for j := 0; j < n; j++ {
						btmp[j] *= alpha
					}
				}
				for k, vb := range btmp {
					if vb == 0 {
						continue
					}
					if nonUnit {
						btmp[k] /= a[k*lda+k]
					}
					btmpk := btmp[k+1 : n]
					f64.AxpyUnitaryTo(btmpk, -btmp[k], a[k*lda+k+1:k*lda+n], btmpk)
				}
			}
			return
		}
		for i := 0; i < m; i++ {
			// Rows of b are ldb apart; lda is only the stride of a.
			btmp := b[i*ldb : i*ldb+n]
			if alpha != 1 {
				for j := 0; j < n; j++ {
					btmp[j] *= alpha
				}
			}
			for k := n - 1; k >= 0; k-- {
				if btmp[k] != 0 {
					if nonUnit {
						btmp[k] /= a[k*lda+k]
					}
					f64.AxpyUnitaryTo(btmp, -btmp[k], a[k*lda:k*lda+k], btmp)
				}
			}
		}
		return
	}
	// Cases where a is transposed.
	if ul == blas.Upper {
		for i := 0; i < m; i++ {
			// Rows of b are ldb apart; lda is only the stride of a.
			btmp := b[i*ldb : i*ldb+n]
			for j := n - 1; j >= 0; j-- {
				tmp := alpha*btmp[j] - f64.DotUnitary(a[j*lda+j+1:j*lda+n], btmp[j+1:])
				if nonUnit {
					tmp /= a[j*lda+j]
				}
				btmp[j] = tmp
			}
		}
		return
	}
	for i := 0; i < m; i++ {
		// Rows of b are ldb apart; lda is only the stride of a.
		btmp := b[i*ldb : i*ldb+n]
		for j := 0; j < n; j++ {
			tmp := alpha*btmp[j] - f64.DotUnitary(a[j*lda:j*lda+j], btmp)
			if nonUnit {
				tmp /= a[j*lda+j]
			}
			btmp[j] = tmp
		}
	}
}
// Dsymm performs one of
//  C = alpha * A * B + beta * C, if side == blas.Left,
//  C = alpha * B * A + beta * C, if side == blas.Right,
// where A is an n×n or m×m symmetric matrix, B and C are m×n matrices, and alpha
// and beta are scalars. A is m×m when s == blas.Left and n×n when
// s == blas.Right. Only the triangle of A selected by ul is read.
//
// Dsymm panics if s or ul is not a valid value, if m or n is negative, or if
// a, b or c is too short for the given dimensions and strides.
func (Implementation) Dsymm(s blas.Side, ul blas.Uplo, m, n int, alpha float64, a []float64, lda int, b []float64, ldb int, beta float64, c []float64, ldc int) {
	if s != blas.Right && s != blas.Left {
		panic(badSide)
	}
	if ul != blas.Lower && ul != blas.Upper {
		panic(badUplo)
	}
	if m < 0 {
		panic(mLT0)
	}
	if n < 0 {
		panic(nLT0)
	}
	// k is the order of A: m when A multiplies from the left, n otherwise.
	var k int
	if s == blas.Left {
		k = m
	} else {
		k = n
	}
	if lda*(k-1)+k > len(a) || lda < max(1, k) {
		panic(badLdA)
	}
	if ldb*(m-1)+n > len(b) || ldb < max(1, n) {
		panic(badLdB)
	}
	if ldc*(m-1)+n > len(c) || ldc < max(1, n) {
		panic(badLdC)
	}
	if m == 0 || n == 0 {
		return
	}
	if alpha == 0 && beta == 1 {
		return
	}
	if alpha == 0 {
		if beta == 0 {
			for i := 0; i < m; i++ {
				ctmp := c[i*ldc : i*ldc+n]
				for j := range ctmp {
					ctmp[j] = 0
				}
			}
			return
		}
		for i := 0; i < m; i++ {
			ctmp := c[i*ldc : i*ldc+n]
			for j := 0; j < n; j++ {
				ctmp[j] *= beta
			}
		}
		return
	}

	isUpper := ul == blas.Upper
	if s == blas.Left {
		for i := 0; i < m; i++ {
			// Diagonal term together with the beta scaling of row i of C.
			atmp := alpha * a[i*lda+i]
			btmp := b[i*ldb : i*ldb+n]
			ctmp := c[i*ldc : i*ldc+n]
			for j, v := range btmp {
				ctmp[j] *= beta
				ctmp[j] += atmp * v
			}

			// Off-diagonal terms: A[i][k] is read from whichever triangle
			// is stored.
			for k := 0; k < i; k++ {
				var atmp float64
				if isUpper {
					atmp = a[k*lda+i]
				} else {
					atmp = a[i*lda+k]
				}
				atmp *= alpha
				f64.AxpyUnitaryTo(ctmp, atmp, b[k*ldb:k*ldb+n], ctmp)
			}
			for k := i + 1; k < m; k++ {
				var atmp float64
				if isUpper {
					atmp = a[i*lda+k]
				} else {
					atmp = a[k*lda+i]
				}
				atmp *= alpha
				f64.AxpyUnitaryTo(ctmp, atmp, b[k*ldb:k*ldb+n], ctmp)
			}
		}
		return
	}
	if isUpper {
		for i := 0; i < m; i++ {
			// j runs downwards so that C[i][j+1:] has already been scaled
			// by beta before the updates below add to it.
			for j := n - 1; j >= 0; j-- {
				tmp := alpha * b[i*ldb+j]
				var tmp2 float64
				atmp := a[j*lda+j+1 : j*lda+n]
				btmp := b[i*ldb+j+1 : i*ldb+n]
				ctmp := c[i*ldc+j+1 : i*ldc+n]
				for k, v := range atmp {
					ctmp[k] += tmp * v
					tmp2 += btmp[k] * v
				}
				c[i*ldc+j] *= beta
				c[i*ldc+j] += tmp*a[j*lda+j] + alpha*tmp2
			}
		}
		return
	}
	for i := 0; i < m; i++ {
		// j runs upwards so that C[i][:j] has already been scaled by beta
		// before the updates below add to it.
		for j := 0; j < n; j++ {
			tmp := alpha * b[i*ldb+j]
			var tmp2 float64
			atmp := a[j*lda : j*lda+j]
			btmp := b[i*ldb : i*ldb+j]
			ctmp := c[i*ldc : i*ldc+j]
			for k, v := range atmp {
				ctmp[k] += tmp * v
				tmp2 += btmp[k] * v
			}
			c[i*ldc+j] *= beta
			c[i*ldc+j] += tmp*a[j*lda+j] + alpha*tmp2
		}
	}
}
// Dsyrk performs the symmetric rank-k operation
//  C = alpha * A * A^T + beta * C, if tA == blas.NoTrans,
//  C = alpha * A^T * A + beta * C, if tA == blas.Trans or blas.ConjTrans,
// where C is an n×n symmetric matrix, A is an n×k matrix if
// tA == blas.NoTrans and a k×n matrix otherwise, and alpha and beta are
// scalars. Only the triangle of C selected by ul is read and written.
//
// Dsyrk panics if ul or tA is not a valid value, if n or k is negative, or
// if a or c is too short for the given dimensions and strides.
func (Implementation) Dsyrk(ul blas.Uplo, tA blas.Transpose, n, k int, alpha float64, a []float64, lda int, beta float64, c []float64, ldc int) {
	if ul != blas.Lower && ul != blas.Upper {
		panic(badUplo)
	}
	if tA != blas.Trans && tA != blas.NoTrans && tA != blas.ConjTrans {
		panic(badTranspose)
	}
	if n < 0 {
		panic(nLT0)
	}
	if k < 0 {
		panic(kLT0)
	}
	if ldc < n {
		panic(badLdC)
	}
	// Stored shape of A.
	var row, col int
	if tA == blas.NoTrans {
		row, col = n, k
	} else {
		row, col = k, n
	}
	if lda*(row-1)+col > len(a) || lda < max(1, col) {
		panic(badLdA)
	}
	if ldc*(n-1)+n > len(c) || ldc < max(1, n) {
		panic(badLdC)
	}
	if alpha == 0 {
		// A does not contribute: only scale (or clear) the triangle of C.
		if beta == 0 {
			if ul == blas.Upper {
				for i := 0; i < n; i++ {
					ctmp := c[i*ldc+i : i*ldc+n]
					for j := range ctmp {
						ctmp[j] = 0
					}
				}
				return
			}
			for i := 0; i < n; i++ {
				ctmp := c[i*ldc : i*ldc+i+1]
				for j := range ctmp {
					ctmp[j] = 0
				}
			}
			return
		}
		if ul == blas.Upper {
			for i := 0; i < n; i++ {
				ctmp := c[i*ldc+i : i*ldc+n]
				for j := range ctmp {
					ctmp[j] *= beta
				}
			}
			return
		}
		for i := 0; i < n; i++ {
			ctmp := c[i*ldc : i*ldc+i+1]
			for j := range ctmp {
				ctmp[j] *= beta
			}
		}
		return
	}
	if tA == blas.NoTrans {
		// C[i][j] is the dot product of rows i and j of A.
		if ul == blas.Upper {
			for i := 0; i < n; i++ {
				ctmp := c[i*ldc+i : i*ldc+n]
				atmp := a[i*lda : i*lda+k]
				for jc, vc := range ctmp {
					j := jc + i
					ctmp[jc] = vc*beta + alpha*f64.DotUnitary(atmp, a[j*lda:j*lda+k])
				}
			}
			return
		}
		for i := 0; i < n; i++ {
			atmp := a[i*lda : i*lda+k]
			for j, vc := range c[i*ldc : i*ldc+i+1] {
				c[i*ldc+j] = vc*beta + alpha*f64.DotUnitary(a[j*lda:j*lda+k], atmp)
			}
		}
		return
	}
	// Cases where a is transposed.
	if ul == blas.Upper {
		for i := 0; i < n; i++ {
			ctmp := c[i*ldc+i : i*ldc+n]
			if beta != 1 {
				for j := range ctmp {
					ctmp[j] *= beta
				}
			}
			for l := 0; l < k; l++ {
				tmp := alpha * a[l*lda+i]
				if tmp != 0 {
					f64.AxpyUnitaryTo(ctmp, tmp, a[l*lda+i:l*lda+n], ctmp)
				}
			}
		}
		return
	}
	for i := 0; i < n; i++ {
		ctmp := c[i*ldc : i*ldc+i+1]
		// The scaling must also run for beta == 0, otherwise the previous
		// contents of C would be added to the result. Only beta == 1 is a
		// no-op.
		if beta != 1 {
			for j := range ctmp {
				ctmp[j] *= beta
			}
		}
		for l := 0; l < k; l++ {
			tmp := alpha * a[l*lda+i]
			if tmp != 0 {
				f64.AxpyUnitaryTo(ctmp, tmp, a[l*lda:l*lda+i+1], ctmp)
			}
		}
	}
}
// Dsyr2k performs the symmetric rank 2k operation
//  C = alpha * A * B^T + alpha * B * A^T + beta * C, if tA == blas.NoTrans,
//  C = alpha * A^T * B + alpha * B^T * A + beta * C, if tA == blas.Trans or blas.ConjTrans,
// where C is an n×n symmetric matrix. A and B are n×k matrices if
// tA == NoTrans and k×n otherwise. alpha and beta are scalars. Only the
// triangle of C selected by ul is read and written.
//
// Dsyr2k panics if ul or tA is not a valid value, if n or k is negative, or
// if a, b or c is too short for the given dimensions and strides.
func (Implementation) Dsyr2k(ul blas.Uplo, tA blas.Transpose, n, k int, alpha float64, a []float64, lda int, b []float64, ldb int, beta float64, c []float64, ldc int) {
	if ul != blas.Lower && ul != blas.Upper {
		panic(badUplo)
	}
	if tA != blas.Trans && tA != blas.NoTrans && tA != blas.ConjTrans {
		panic(badTranspose)
	}
	if n < 0 {
		panic(nLT0)
	}
	if k < 0 {
		panic(kLT0)
	}
	if ldc < n {
		panic(badLdC)
	}
	// Stored shape of A and B.
	var row, col int
	if tA == blas.NoTrans {
		row, col = n, k
	} else {
		row, col = k, n
	}
	if lda*(row-1)+col > len(a) || lda < max(1, col) {
		panic(badLdA)
	}
	if ldb*(row-1)+col > len(b) || ldb < max(1, col) {
		panic(badLdB)
	}
	if ldc*(n-1)+n > len(c) || ldc < max(1, n) {
		panic(badLdC)
	}
	if alpha == 0 {
		// A and B do not contribute: only scale (or clear) the triangle of C.
		if beta == 0 {
			if ul == blas.Upper {
				for i := 0; i < n; i++ {
					ctmp := c[i*ldc+i : i*ldc+n]
					for j := range ctmp {
						ctmp[j] = 0
					}
				}
				return
			}
			for i := 0; i < n; i++ {
				ctmp := c[i*ldc : i*ldc+i+1]
				for j := range ctmp {
					ctmp[j] = 0
				}
			}
			return
		}
		if ul == blas.Upper {
			for i := 0; i < n; i++ {
				ctmp := c[i*ldc+i : i*ldc+n]
				for j := range ctmp {
					ctmp[j] *= beta
				}
			}
			return
		}
		for i := 0; i < n; i++ {
			ctmp := c[i*ldc : i*ldc+i+1]
			for j := range ctmp {
				ctmp[j] *= beta
			}
		}
		return
	}
	if tA == blas.NoTrans {
		if ul == blas.Upper {
			for i := 0; i < n; i++ {
				atmp := a[i*lda : i*lda+k]
				btmp := b[i*ldb : i*ldb+k]
				ctmp := c[i*ldc+i : i*ldc+n]
				for jc := range ctmp {
					j := i + jc
					var tmp1, tmp2 float64
					binner := b[j*ldb : j*ldb+k]
					for l, v := range a[j*lda : j*lda+k] {
						tmp1 += v * btmp[l]
						tmp2 += atmp[l] * binner[l]
					}
					ctmp[jc] *= beta
					ctmp[jc] += alpha * (tmp1 + tmp2)
				}
			}
			return
		}
		for i := 0; i < n; i++ {
			atmp := a[i*lda : i*lda+k]
			btmp := b[i*ldb : i*ldb+k]
			ctmp := c[i*ldc : i*ldc+i+1]
			for j := 0; j <= i; j++ {
				var tmp1, tmp2 float64
				binner := b[j*ldb : j*ldb+k]
				for l, v := range a[j*lda : j*lda+k] {
					tmp1 += v * btmp[l]
					tmp2 += atmp[l] * binner[l]
				}
				ctmp[j] *= beta
				ctmp[j] += alpha * (tmp1 + tmp2)
			}
		}
		return
	}
	// Cases where a and b are transposed.
	if ul == blas.Upper {
		for i := 0; i < n; i++ {
			ctmp := c[i*ldc+i : i*ldc+n]
			if beta != 1 {
				for j := range ctmp {
					ctmp[j] *= beta
				}
			}
			for l := 0; l < k; l++ {
				// b is indexed with its own stride ldb, not lda.
				tmp1 := alpha * b[l*ldb+i]
				tmp2 := alpha * a[l*lda+i]
				btmp := b[l*ldb+i : l*ldb+n]
				if tmp1 != 0 || tmp2 != 0 {
					for j, v := range a[l*lda+i : l*lda+n] {
						ctmp[j] += v*tmp1 + btmp[j]*tmp2
					}
				}
			}
		}
		return
	}
	for i := 0; i < n; i++ {
		ctmp := c[i*ldc : i*ldc+i+1]
		if beta != 1 {
			for j := range ctmp {
				ctmp[j] *= beta
			}
		}
		for l := 0; l < k; l++ {
			// b is indexed with its own stride ldb, not lda.
			tmp1 := alpha * b[l*ldb+i]
			tmp2 := alpha * a[l*lda+i]
			btmp := b[l*ldb : l*ldb+i+1]
			if tmp1 != 0 || tmp2 != 0 {
				for j, v := range a[l*lda : l*lda+i+1] {
					ctmp[j] += v*tmp1 + btmp[j]*tmp2
				}
			}
		}
	}
}
// Dtrmm performs one of the in-place triangular matrix products
//  B = alpha * A * B,   if tA == blas.NoTrans and side == blas.Left,
//  B = alpha * A^T * B, if tA == blas.Trans or blas.ConjTrans, and side == blas.Left,
//  B = alpha * B * A,   if tA == blas.NoTrans and side == blas.Right,
//  B = alpha * B * A^T, if tA == blas.Trans or blas.ConjTrans, and side == blas.Right,
// where A is an n×n or m×m triangular matrix, and B is an m×n matrix. A is
// m×m when s == blas.Left and n×n when s == blas.Right. When d == blas.Unit
// the diagonal of A is not read.
//
// Dtrmm panics if any of s, ul, tA or d is not a valid value, if m or n is
// negative, or if a or b is too short for the given dimensions and strides.
func (Implementation) Dtrmm(s blas.Side, ul blas.Uplo, tA blas.Transpose, d blas.Diag, m, n int, alpha float64, a []float64, lda int, b []float64, ldb int) {
	switch {
	case s != blas.Left && s != blas.Right:
		panic(badSide)
	case ul != blas.Lower && ul != blas.Upper:
		panic(badUplo)
	case tA != blas.NoTrans && tA != blas.Trans && tA != blas.ConjTrans:
		panic(badTranspose)
	case d != blas.NonUnit && d != blas.Unit:
		panic(badDiag)
	case m < 0:
		panic(mLT0)
	case n < 0:
		panic(nLT0)
	}

	// order is the dimension of the square matrix A.
	order := n
	if s == blas.Left {
		order = m
	}
	if lda*(order-1)+order > len(a) || lda < max(1, order) {
		panic(badLdA)
	}
	if ldb*(m-1)+n > len(b) || ldb < max(1, n) {
		panic(badLdB)
	}

	if alpha == 0 {
		for i := 0; i < m; i++ {
			row := b[i*ldb : i*ldb+n]
			for j := range row {
				row[j] = 0
			}
		}
		return
	}

	nonUnit := d == blas.NonUnit
	left := s == blas.Left
	noTrans := tA == blas.NoTrans
	upper := ul == blas.Upper

	switch {
	case left && noTrans && upper:
		// Row i of the result only needs rows i..m-1 of B, so rows are
		// overwritten top to bottom.
		for i := 0; i < m; i++ {
			scale := alpha
			if nonUnit {
				scale *= a[i*lda+i]
			}
			row := b[i*ldb : i*ldb+n]
			for j := range row {
				row[j] *= scale
			}
			for off, aik := range a[i*lda+i+1 : i*lda+m] {
				w := alpha * aik
				if w == 0 {
					continue
				}
				src := i + 1 + off
				f64.AxpyUnitaryTo(row, w, b[src*ldb:src*ldb+n], row)
			}
		}

	case left && noTrans:
		// Lower triangle: rows are overwritten bottom to top.
		for i := m - 1; i >= 0; i-- {
			scale := alpha
			if nonUnit {
				scale *= a[i*lda+i]
			}
			row := b[i*ldb : i*ldb+n]
			for j := range row {
				row[j] *= scale
			}
			for src, aik := range a[i*lda : i*lda+i] {
				w := alpha * aik
				if w == 0 {
					continue
				}
				f64.AxpyUnitaryTo(row, w, b[src*ldb:src*ldb+n], row)
			}
		}

	case left && upper:
		// Transposed upper triangle: row r of B is scattered into the rows
		// below it before being scaled itself.
		for r := m - 1; r >= 0; r-- {
			pivot := b[r*ldb : r*ldb+n]
			for off, ari := range a[r*lda+r+1 : r*lda+m] {
				w := alpha * ari
				if w == 0 {
					continue
				}
				i := r + 1 + off
				dst := b[i*ldb : i*ldb+n]
				f64.AxpyUnitaryTo(dst, w, pivot, dst)
			}
			scale := alpha
			if nonUnit {
				scale *= a[r*lda+r]
			}
			if scale != 1 {
				for j := range pivot {
					pivot[j] *= scale
				}
			}
		}

	case left:
		// Transposed lower triangle: row r of B is scattered into the rows
		// above it before being scaled itself.
		for r := 0; r < m; r++ {
			pivot := b[r*ldb : r*ldb+n]
			for i, ari := range a[r*lda : r*lda+r] {
				w := alpha * ari
				if w == 0 {
					continue
				}
				dst := b[i*ldb : i*ldb+n]
				f64.AxpyUnitaryTo(dst, w, pivot, dst)
			}
			scale := alpha
			if nonUnit {
				scale *= a[r*lda+r]
			}
			if scale != 1 {
				for j := range pivot {
					pivot[j] *= scale
				}
			}
		}

	case noTrans && upper:
		// A on the right: every row of B is transformed independently.
		for i := 0; i < m; i++ {
			row := b[i*ldb : i*ldb+n]
			for c := n - 1; c >= 0; c-- {
				w := alpha * row[c]
				if w == 0 {
					continue
				}
				row[c] = w
				if nonUnit {
					row[c] *= a[c*lda+c]
				}
				for off, acj := range a[c*lda+c+1 : c*lda+n] {
					row[c+1+off] += w * acj
				}
			}
		}

	case noTrans:
		for i := 0; i < m; i++ {
			row := b[i*ldb : i*ldb+n]
			for c := 0; c < n; c++ {
				w := alpha * row[c]
				if w == 0 {
					continue
				}
				row[c] = w
				if nonUnit {
					row[c] *= a[c*lda+c]
				}
				f64.AxpyUnitaryTo(row, w, a[c*lda:c*lda+c], row)
			}
		}

	case upper:
		// B * A^T with A upper: entry j needs entries j..n-1 of the row,
		// so entries are overwritten left to right.
		for i := 0; i < m; i++ {
			row := b[i*ldb : i*ldb+n]
			for j, bij := range row {
				acc := bij
				if nonUnit {
					acc *= a[j*lda+j]
				}
				acc += f64.DotUnitary(a[j*lda+j+1:j*lda+n], row[j+1:n])
				row[j] = alpha * acc
			}
		}

	default:
		// B * A^T with A lower: entries are overwritten right to left.
		for i := 0; i < m; i++ {
			row := b[i*ldb : i*ldb+n]
			for j := n - 1; j >= 0; j-- {
				acc := row[j]
				if nonUnit {
					acc *= a[j*lda+j]
				}
				acc += f64.DotUnitary(a[j*lda:j*lda+j], row[:j])
				row[j] = alpha * acc
			}
		}
	}
}

View File

@@ -0,0 +1,35 @@
// Copyright ©2014 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package native
import (
"testing"
"github.com/gonum/blas/testblas"
)
// TestDgemm passes the package-level impl to testblas.TestDgemm.
func TestDgemm(t *testing.T) {
	testblas.TestDgemm(t, impl)
}
// TestDsymm passes the package-level impl to testblas.DsymmTest.
func TestDsymm(t *testing.T) {
	testblas.DsymmTest(t, impl)
}
// TestDtrsm passes the package-level impl to testblas.DtrsmTest.
func TestDtrsm(t *testing.T) {
	testblas.DtrsmTest(t, impl)
}
// TestDsyrk passes the package-level impl to testblas.DsyrkTest.
func TestDsyrk(t *testing.T) {
	testblas.DsyrkTest(t, impl)
}
// TestDsyr2k passes the package-level impl to testblas.Dsyr2kTest.
func TestDsyr2k(t *testing.T) {
	testblas.Dsyr2kTest(t, impl)
}
// TestDtrmm passes the package-level impl to testblas.DtrmmTest.
func TestDtrmm(t *testing.T) {
	testblas.DtrmmTest(t, impl)
}

843
blas/native/level3single.go Normal file
View File

@@ -0,0 +1,843 @@
// Code generated by "go generate github.com/gonum/blas/native"; DO NOT EDIT.
// Copyright ©2014 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package native
import (
"github.com/gonum/blas"
"github.com/gonum/internal/asm/f32"
)
// Compile-time assertion that Implementation satisfies blas.Float32Level3.
var _ blas.Float32Level3 = Implementation{}
// Strsm solves
//  A * X = alpha * B,   if tA == blas.NoTrans side == blas.Left,
//  A^T * X = alpha * B, if tA == blas.Trans or blas.ConjTrans, and side == blas.Left,
//  X * A = alpha * B,   if tA == blas.NoTrans side == blas.Right,
//  X * A^T = alpha * B, if tA == blas.Trans or blas.ConjTrans, and side == blas.Right,
// where A is an n×n or m×m triangular matrix, X is an m×n matrix, and alpha is a
// scalar. A is m×m when s == blas.Left and n×n when s == blas.Right.
//
// At entry to the function, X contains the values of B, and the result is
// stored in place into X.
//
// When d == blas.Unit the diagonal of A is not read. If alpha is zero, X is
// set to zero without reading A.
//
// Strsm panics if any of s, ul, tA or d is not a valid value, if m or n is
// negative, or if a or b is too short for the given dimensions and strides.
//
// No check is made that A is invertible.
//
// Float32 implementations are autogenerated and not directly tested.
func (Implementation) Strsm(s blas.Side, ul blas.Uplo, tA blas.Transpose, d blas.Diag, m, n int, alpha float32, a []float32, lda int, b []float32, ldb int) {
	if s != blas.Left && s != blas.Right {
		panic(badSide)
	}
	if ul != blas.Lower && ul != blas.Upper {
		panic(badUplo)
	}
	if tA != blas.NoTrans && tA != blas.Trans && tA != blas.ConjTrans {
		panic(badTranspose)
	}
	if d != blas.NonUnit && d != blas.Unit {
		panic(badDiag)
	}
	if m < 0 {
		panic(mLT0)
	}
	if n < 0 {
		panic(nLT0)
	}
	if ldb < n {
		panic(badLdB)
	}
	// k is the order of A: m when A multiplies from the left, n otherwise.
	var k int
	if s == blas.Left {
		k = m
	} else {
		k = n
	}
	if lda*(k-1)+k > len(a) || lda < max(1, k) {
		panic(badLdA)
	}
	if ldb*(m-1)+n > len(b) || ldb < max(1, n) {
		panic(badLdB)
	}

	if m == 0 || n == 0 {
		return
	}

	if alpha == 0 {
		for i := 0; i < m; i++ {
			btmp := b[i*ldb : i*ldb+n]
			for j := range btmp {
				btmp[j] = 0
			}
		}
		return
	}
	nonUnit := d == blas.NonUnit
	if s == blas.Left {
		if tA == blas.NoTrans {
			if ul == blas.Upper {
				// Back substitution: the last row of X is solved first.
				for i := m - 1; i >= 0; i-- {
					btmp := b[i*ldb : i*ldb+n]
					if alpha != 1 {
						for j := range btmp {
							btmp[j] *= alpha
						}
					}
					for ka, va := range a[i*lda+i+1 : i*lda+m] {
						k := ka + i + 1
						if va != 0 {
							f32.AxpyUnitaryTo(btmp, -va, b[k*ldb:k*ldb+n], btmp)
						}
					}
					if nonUnit {
						tmp := 1 / a[i*lda+i]
						for j := 0; j < n; j++ {
							btmp[j] *= tmp
						}
					}
				}
				return
			}
			// Forward substitution: the first row of X is solved first.
			for i := 0; i < m; i++ {
				btmp := b[i*ldb : i*ldb+n]
				if alpha != 1 {
					for j := 0; j < n; j++ {
						btmp[j] *= alpha
					}
				}
				for k, va := range a[i*lda : i*lda+i] {
					if va != 0 {
						f32.AxpyUnitaryTo(btmp, -va, b[k*ldb:k*ldb+n], btmp)
					}
				}
				if nonUnit {
					tmp := 1 / a[i*lda+i]
					for j := 0; j < n; j++ {
						btmp[j] *= tmp
					}
				}
			}
			return
		}
		// Cases where a is transposed
		if ul == blas.Upper {
			for k := 0; k < m; k++ {
				btmpk := b[k*ldb : k*ldb+n]
				if nonUnit {
					tmp := 1 / a[k*lda+k]
					for j := 0; j < n; j++ {
						btmpk[j] *= tmp
					}
				}
				for ia, va := range a[k*lda+k+1 : k*lda+m] {
					i := ia + k + 1
					if va != 0 {
						btmp := b[i*ldb : i*ldb+n]
						f32.AxpyUnitaryTo(btmp, -va, btmpk, btmp)
					}
				}
				// alpha is applied last so that the eliminations above
				// use the unscaled solution row.
				if alpha != 1 {
					for j := 0; j < n; j++ {
						btmpk[j] *= alpha
					}
				}
			}
			return
		}
		for k := m - 1; k >= 0; k-- {
			btmpk := b[k*ldb : k*ldb+n]
			if nonUnit {
				tmp := 1 / a[k*lda+k]
				for j := 0; j < n; j++ {
					btmpk[j] *= tmp
				}
			}
			for i, va := range a[k*lda : k*lda+k] {
				if va != 0 {
					btmp := b[i*ldb : i*ldb+n]
					f32.AxpyUnitaryTo(btmp, -va, btmpk, btmp)
				}
			}
			if alpha != 1 {
				for j := 0; j < n; j++ {
					btmpk[j] *= alpha
				}
			}
		}
		return
	}
	// Cases where a is to the right of X. Each row of X is solved
	// independently against A.
	if tA == blas.NoTrans {
		if ul == blas.Upper {
			for i := 0; i < m; i++ {
				btmp := b[i*ldb : i*ldb+n]
				if alpha != 1 {
					for j := 0; j < n; j++ {
						btmp[j] *= alpha
					}
				}
				for k, vb := range btmp {
					if vb == 0 {
						continue
					}
					if nonUnit {
						btmp[k] /= a[k*lda+k]
					}
					btmpk := btmp[k+1 : n]
					f32.AxpyUnitaryTo(btmpk, -btmp[k], a[k*lda+k+1:k*lda+n], btmpk)
				}
			}
			return
		}
		for i := 0; i < m; i++ {
			// Rows of b are ldb apart; lda is only the stride of a.
			btmp := b[i*ldb : i*ldb+n]
			if alpha != 1 {
				for j := 0; j < n; j++ {
					btmp[j] *= alpha
				}
			}
			for k := n - 1; k >= 0; k-- {
				if btmp[k] != 0 {
					if nonUnit {
						btmp[k] /= a[k*lda+k]
					}
					f32.AxpyUnitaryTo(btmp, -btmp[k], a[k*lda:k*lda+k], btmp)
				}
			}
		}
		return
	}
	// Cases where a is transposed.
	if ul == blas.Upper {
		for i := 0; i < m; i++ {
			// Rows of b are ldb apart; lda is only the stride of a.
			btmp := b[i*ldb : i*ldb+n]
			for j := n - 1; j >= 0; j-- {
				tmp := alpha*btmp[j] - f32.DotUnitary(a[j*lda+j+1:j*lda+n], btmp[j+1:])
				if nonUnit {
					tmp /= a[j*lda+j]
				}
				btmp[j] = tmp
			}
		}
		return
	}
	for i := 0; i < m; i++ {
		// Rows of b are ldb apart; lda is only the stride of a.
		btmp := b[i*ldb : i*ldb+n]
		for j := 0; j < n; j++ {
			tmp := alpha*btmp[j] - f32.DotUnitary(a[j*lda:j*lda+j], btmp)
			if nonUnit {
				tmp /= a[j*lda+j]
			}
			btmp[j] = tmp
		}
	}
}
// Ssymm performs one of
//  C = alpha * A * B + beta * C, if side == blas.Left,
//  C = alpha * B * A + beta * C, if side == blas.Right,
// where A is an n×n or m×m symmetric matrix, B and C are m×n matrices, and alpha
// and beta are scalars. A is m×m when s == blas.Left and n×n when
// s == blas.Right. Only the triangle of A selected by ul is read.
//
// Ssymm panics if s or ul is not a valid value, if m or n is negative, or if
// a, b or c is too short for the given dimensions and strides.
//
// Float32 implementations are autogenerated and not directly tested.
func (Implementation) Ssymm(s blas.Side, ul blas.Uplo, m, n int, alpha float32, a []float32, lda int, b []float32, ldb int, beta float32, c []float32, ldc int) {
	if s != blas.Right && s != blas.Left {
		panic(badSide)
	}
	if ul != blas.Lower && ul != blas.Upper {
		panic(badUplo)
	}
	if m < 0 {
		panic(mLT0)
	}
	if n < 0 {
		panic(nLT0)
	}
	// k is the order of A: m when A multiplies from the left, n otherwise.
	var k int
	if s == blas.Left {
		k = m
	} else {
		k = n
	}
	if lda*(k-1)+k > len(a) || lda < max(1, k) {
		panic(badLdA)
	}
	if ldb*(m-1)+n > len(b) || ldb < max(1, n) {
		panic(badLdB)
	}
	if ldc*(m-1)+n > len(c) || ldc < max(1, n) {
		panic(badLdC)
	}
	if m == 0 || n == 0 {
		return
	}
	if alpha == 0 && beta == 1 {
		return
	}
	if alpha == 0 {
		if beta == 0 {
			for i := 0; i < m; i++ {
				ctmp := c[i*ldc : i*ldc+n]
				for j := range ctmp {
					ctmp[j] = 0
				}
			}
			return
		}
		for i := 0; i < m; i++ {
			ctmp := c[i*ldc : i*ldc+n]
			for j := 0; j < n; j++ {
				ctmp[j] *= beta
			}
		}
		return
	}

	isUpper := ul == blas.Upper
	if s == blas.Left {
		for i := 0; i < m; i++ {
			// Diagonal term together with the beta scaling of row i of C.
			atmp := alpha * a[i*lda+i]
			btmp := b[i*ldb : i*ldb+n]
			ctmp := c[i*ldc : i*ldc+n]
			for j, v := range btmp {
				ctmp[j] *= beta
				ctmp[j] += atmp * v
			}

			// Off-diagonal terms: A[i][k] is read from whichever triangle
			// is stored.
			for k := 0; k < i; k++ {
				var atmp float32
				if isUpper {
					atmp = a[k*lda+i]
				} else {
					atmp = a[i*lda+k]
				}
				atmp *= alpha
				f32.AxpyUnitaryTo(ctmp, atmp, b[k*ldb:k*ldb+n], ctmp)
			}
			for k := i + 1; k < m; k++ {
				var atmp float32
				if isUpper {
					atmp = a[i*lda+k]
				} else {
					atmp = a[k*lda+i]
				}
				atmp *= alpha
				f32.AxpyUnitaryTo(ctmp, atmp, b[k*ldb:k*ldb+n], ctmp)
			}
		}
		return
	}
	if isUpper {
		for i := 0; i < m; i++ {
			// j runs downwards so that C[i][j+1:] has already been scaled
			// by beta before the updates below add to it.
			for j := n - 1; j >= 0; j-- {
				tmp := alpha * b[i*ldb+j]
				var tmp2 float32
				atmp := a[j*lda+j+1 : j*lda+n]
				btmp := b[i*ldb+j+1 : i*ldb+n]
				ctmp := c[i*ldc+j+1 : i*ldc+n]
				for k, v := range atmp {
					ctmp[k] += tmp * v
					tmp2 += btmp[k] * v
				}
				c[i*ldc+j] *= beta
				c[i*ldc+j] += tmp*a[j*lda+j] + alpha*tmp2
			}
		}
		return
	}
	for i := 0; i < m; i++ {
		// j runs upwards so that C[i][:j] has already been scaled by beta
		// before the updates below add to it.
		for j := 0; j < n; j++ {
			tmp := alpha * b[i*ldb+j]
			var tmp2 float32
			atmp := a[j*lda : j*lda+j]
			btmp := b[i*ldb : i*ldb+j]
			ctmp := c[i*ldc : i*ldc+j]
			for k, v := range atmp {
				ctmp[k] += tmp * v
				tmp2 += btmp[k] * v
			}
			c[i*ldc+j] *= beta
			c[i*ldc+j] += tmp*a[j*lda+j] + alpha*tmp2
		}
	}
}
// Ssyrk performs the symmetric rank-k operation
//  C = alpha * A * A^T + beta * C, if tA == blas.NoTrans,
//  C = alpha * A^T * A + beta * C, if tA == blas.Trans or blas.ConjTrans,
// where C is an n×n symmetric matrix, A is an n×k matrix if
// tA == blas.NoTrans and a k×n matrix otherwise, and alpha and beta are
// scalars. Only the triangle of C selected by ul is read and written.
//
// Ssyrk panics if ul or tA is not a valid value, if n or k is negative, or
// if a or c is too short for the given dimensions and strides.
//
// Float32 implementations are autogenerated and not directly tested.
func (Implementation) Ssyrk(ul blas.Uplo, tA blas.Transpose, n, k int, alpha float32, a []float32, lda int, beta float32, c []float32, ldc int) {
	if ul != blas.Lower && ul != blas.Upper {
		panic(badUplo)
	}
	if tA != blas.Trans && tA != blas.NoTrans && tA != blas.ConjTrans {
		panic(badTranspose)
	}
	if n < 0 {
		panic(nLT0)
	}
	if k < 0 {
		panic(kLT0)
	}
	if ldc < n {
		panic(badLdC)
	}
	// Stored shape of A.
	var row, col int
	if tA == blas.NoTrans {
		row, col = n, k
	} else {
		row, col = k, n
	}
	if lda*(row-1)+col > len(a) || lda < max(1, col) {
		panic(badLdA)
	}
	if ldc*(n-1)+n > len(c) || ldc < max(1, n) {
		panic(badLdC)
	}
	if alpha == 0 {
		// A does not contribute: only scale (or clear) the triangle of C.
		if beta == 0 {
			if ul == blas.Upper {
				for i := 0; i < n; i++ {
					ctmp := c[i*ldc+i : i*ldc+n]
					for j := range ctmp {
						ctmp[j] = 0
					}
				}
				return
			}
			for i := 0; i < n; i++ {
				ctmp := c[i*ldc : i*ldc+i+1]
				for j := range ctmp {
					ctmp[j] = 0
				}
			}
			return
		}
		if ul == blas.Upper {
			for i := 0; i < n; i++ {
				ctmp := c[i*ldc+i : i*ldc+n]
				for j := range ctmp {
					ctmp[j] *= beta
				}
			}
			return
		}
		for i := 0; i < n; i++ {
			ctmp := c[i*ldc : i*ldc+i+1]
			for j := range ctmp {
				ctmp[j] *= beta
			}
		}
		return
	}
	if tA == blas.NoTrans {
		// C[i][j] is the dot product of rows i and j of A.
		if ul == blas.Upper {
			for i := 0; i < n; i++ {
				ctmp := c[i*ldc+i : i*ldc+n]
				atmp := a[i*lda : i*lda+k]
				for jc, vc := range ctmp {
					j := jc + i
					ctmp[jc] = vc*beta + alpha*f32.DotUnitary(atmp, a[j*lda:j*lda+k])
				}
			}
			return
		}
		for i := 0; i < n; i++ {
			atmp := a[i*lda : i*lda+k]
			for j, vc := range c[i*ldc : i*ldc+i+1] {
				c[i*ldc+j] = vc*beta + alpha*f32.DotUnitary(a[j*lda:j*lda+k], atmp)
			}
		}
		return
	}
	// Cases where a is transposed.
	if ul == blas.Upper {
		for i := 0; i < n; i++ {
			ctmp := c[i*ldc+i : i*ldc+n]
			if beta != 1 {
				for j := range ctmp {
					ctmp[j] *= beta
				}
			}
			for l := 0; l < k; l++ {
				tmp := alpha * a[l*lda+i]
				if tmp != 0 {
					f32.AxpyUnitaryTo(ctmp, tmp, a[l*lda+i:l*lda+n], ctmp)
				}
			}
		}
		return
	}
	for i := 0; i < n; i++ {
		ctmp := c[i*ldc : i*ldc+i+1]
		// The scaling must also run for beta == 0, otherwise the previous
		// contents of C would be added to the result. Only beta == 1 is a
		// no-op.
		if beta != 1 {
			for j := range ctmp {
				ctmp[j] *= beta
			}
		}
		for l := 0; l < k; l++ {
			tmp := alpha * a[l*lda+i]
			if tmp != 0 {
				f32.AxpyUnitaryTo(ctmp, tmp, a[l*lda:l*lda+i+1], ctmp)
			}
		}
	}
}
// Ssyr2k performs the symmetric rank 2k operation
//  C = alpha * A * B^T + alpha * B * A^T + beta * C, if tA == blas.NoTrans,
//  C = alpha * A^T * B + alpha * B^T * A + beta * C, if tA == blas.Trans or blas.ConjTrans,
// where C is an n×n symmetric matrix. A and B are n×k matrices if
// tA == NoTrans and k×n otherwise. alpha and beta are scalars. Only the
// triangle of C selected by ul is read and written.
//
// Ssyr2k panics if ul or tA is not a valid value, if n or k is negative, or
// if a, b or c is too short for the given dimensions and strides.
//
// Float32 implementations are autogenerated and not directly tested.
func (Implementation) Ssyr2k(ul blas.Uplo, tA blas.Transpose, n, k int, alpha float32, a []float32, lda int, b []float32, ldb int, beta float32, c []float32, ldc int) {
	if ul != blas.Lower && ul != blas.Upper {
		panic(badUplo)
	}
	if tA != blas.Trans && tA != blas.NoTrans && tA != blas.ConjTrans {
		panic(badTranspose)
	}
	if n < 0 {
		panic(nLT0)
	}
	if k < 0 {
		panic(kLT0)
	}
	if ldc < n {
		panic(badLdC)
	}
	// Stored shape of A and B.
	var row, col int
	if tA == blas.NoTrans {
		row, col = n, k
	} else {
		row, col = k, n
	}
	if lda*(row-1)+col > len(a) || lda < max(1, col) {
		panic(badLdA)
	}
	if ldb*(row-1)+col > len(b) || ldb < max(1, col) {
		panic(badLdB)
	}
	if ldc*(n-1)+n > len(c) || ldc < max(1, n) {
		panic(badLdC)
	}
	if alpha == 0 {
		// A and B do not contribute: only scale (or clear) the triangle of C.
		if beta == 0 {
			if ul == blas.Upper {
				for i := 0; i < n; i++ {
					ctmp := c[i*ldc+i : i*ldc+n]
					for j := range ctmp {
						ctmp[j] = 0
					}
				}
				return
			}
			for i := 0; i < n; i++ {
				ctmp := c[i*ldc : i*ldc+i+1]
				for j := range ctmp {
					ctmp[j] = 0
				}
			}
			return
		}
		if ul == blas.Upper {
			for i := 0; i < n; i++ {
				ctmp := c[i*ldc+i : i*ldc+n]
				for j := range ctmp {
					ctmp[j] *= beta
				}
			}
			return
		}
		for i := 0; i < n; i++ {
			ctmp := c[i*ldc : i*ldc+i+1]
			for j := range ctmp {
				ctmp[j] *= beta
			}
		}
		return
	}
	if tA == blas.NoTrans {
		if ul == blas.Upper {
			for i := 0; i < n; i++ {
				atmp := a[i*lda : i*lda+k]
				btmp := b[i*ldb : i*ldb+k]
				ctmp := c[i*ldc+i : i*ldc+n]
				for jc := range ctmp {
					j := i + jc
					var tmp1, tmp2 float32
					binner := b[j*ldb : j*ldb+k]
					for l, v := range a[j*lda : j*lda+k] {
						tmp1 += v * btmp[l]
						tmp2 += atmp[l] * binner[l]
					}
					ctmp[jc] *= beta
					ctmp[jc] += alpha * (tmp1 + tmp2)
				}
			}
			return
		}
		for i := 0; i < n; i++ {
			atmp := a[i*lda : i*lda+k]
			btmp := b[i*ldb : i*ldb+k]
			ctmp := c[i*ldc : i*ldc+i+1]
			for j := 0; j <= i; j++ {
				var tmp1, tmp2 float32
				binner := b[j*ldb : j*ldb+k]
				for l, v := range a[j*lda : j*lda+k] {
					tmp1 += v * btmp[l]
					tmp2 += atmp[l] * binner[l]
				}
				ctmp[j] *= beta
				ctmp[j] += alpha * (tmp1 + tmp2)
			}
		}
		return
	}
	// Cases where a and b are transposed.
	if ul == blas.Upper {
		for i := 0; i < n; i++ {
			ctmp := c[i*ldc+i : i*ldc+n]
			if beta != 1 {
				for j := range ctmp {
					ctmp[j] *= beta
				}
			}
			for l := 0; l < k; l++ {
				// b is indexed with its own stride ldb, not lda.
				tmp1 := alpha * b[l*ldb+i]
				tmp2 := alpha * a[l*lda+i]
				btmp := b[l*ldb+i : l*ldb+n]
				if tmp1 != 0 || tmp2 != 0 {
					for j, v := range a[l*lda+i : l*lda+n] {
						ctmp[j] += v*tmp1 + btmp[j]*tmp2
					}
				}
			}
		}
		return
	}
	for i := 0; i < n; i++ {
		ctmp := c[i*ldc : i*ldc+i+1]
		if beta != 1 {
			for j := range ctmp {
				ctmp[j] *= beta
			}
		}
		for l := 0; l < k; l++ {
			// b is indexed with its own stride ldb, not lda.
			tmp1 := alpha * b[l*ldb+i]
			tmp2 := alpha * a[l*lda+i]
			btmp := b[l*ldb : l*ldb+i+1]
			if tmp1 != 0 || tmp2 != 0 {
				for j, v := range a[l*lda : l*lda+i+1] {
					ctmp[j] += v*tmp1 + btmp[j]*tmp2
				}
			}
		}
	}
}
// Strmm performs
// B = alpha * A * B, if tA == blas.NoTrans and side == blas.Left,
// B = alpha * A^T * B, if tA == blas.Trans or blas.ConjTrans, and side == blas.Left,
// B = alpha * B * A, if tA == blas.NoTrans and side == blas.Right,
// B = alpha * B * A^T, if tA == blas.Trans or blas.ConjTrans, and side == blas.Right,
// where A is an n×n or m×m triangular matrix, and B is an m×n matrix.
//
// A is m×m when s == blas.Left and n×n when s == blas.Right. Only the
// triangle of a selected by ul is read, and when d == blas.Unit the diagonal
// of a is not read and is taken to be 1. b is overwritten with the result.
//
// Strmm panics if s, ul, tA or d is not one of the values tested below, if m
// or n is negative, or if a or b is too short for the given dimensions and
// strides.
//
// Float32 implementations are autogenerated and not directly tested.
func (Implementation) Strmm(s blas.Side, ul blas.Uplo, tA blas.Transpose, d blas.Diag, m, n int, alpha float32, a []float32, lda int, b []float32, ldb int) {
if s != blas.Left && s != blas.Right {
panic(badSide)
}
if ul != blas.Lower && ul != blas.Upper {
panic(badUplo)
}
if tA != blas.NoTrans && tA != blas.Trans && tA != blas.ConjTrans {
panic(badTranspose)
}
if d != blas.NonUnit && d != blas.Unit {
panic(badDiag)
}
if m < 0 {
panic(mLT0)
}
if n < 0 {
panic(nLT0)
}
// k is the order of the triangular matrix a.
var k int
if s == blas.Left {
k = m
} else {
k = n
}
if lda*(k-1)+k > len(a) || lda < max(1, k) {
panic(badLdA)
}
if ldb*(m-1)+n > len(b) || ldb < max(1, n) {
panic(badLdB)
}
// With alpha == 0 the result is the zero matrix; a is not read.
if alpha == 0 {
for i := 0; i < m; i++ {
btmp := b[i*ldb : i*ldb+n]
for j := range btmp {
btmp[j] = 0
}
}
return
}
nonUnit := d == blas.NonUnit
if s == blas.Left {
if tA == blas.NoTrans {
if ul == blas.Upper {
// Row i of the result depends on rows i..m-1 of b. Working from the
// top down means the rows k > i read here have not been overwritten.
for i := 0; i < m; i++ {
tmp := alpha
if nonUnit {
tmp *= a[i*lda+i]
}
btmp := b[i*ldb : i*ldb+n]
for j := range btmp {
btmp[j] *= tmp
}
for ka, va := range a[i*lda+i+1 : i*lda+m] {
k := ka + i + 1
tmp := alpha * va
if tmp != 0 {
f32.AxpyUnitaryTo(btmp, tmp, b[k*ldb:k*ldb+n], btmp)
}
}
}
return
}
// Lower triangular: row i depends on rows 0..i of b, so work from the
// bottom up.
for i := m - 1; i >= 0; i-- {
tmp := alpha
if nonUnit {
tmp *= a[i*lda+i]
}
btmp := b[i*ldb : i*ldb+n]
for j := range btmp {
btmp[j] *= tmp
}
for k, va := range a[i*lda : i*lda+i] {
tmp := alpha * va
if tmp != 0 {
f32.AxpyUnitaryTo(btmp, tmp, b[k*ldb:k*ldb+n], btmp)
}
}
}
return
}
// Cases where a is transposed.
if ul == blas.Upper {
// Row k of b is first added, scaled by the off-diagonal entries of row k
// of a, into the rows i > k, and only then scaled itself.
for k := m - 1; k >= 0; k-- {
btmpk := b[k*ldb : k*ldb+n]
for ia, va := range a[k*lda+k+1 : k*lda+m] {
i := ia + k + 1
btmp := b[i*ldb : i*ldb+n]
tmp := alpha * va
if tmp != 0 {
f32.AxpyUnitaryTo(btmp, tmp, btmpk, btmp)
}
}
tmp := alpha
if nonUnit {
tmp *= a[k*lda+k]
}
if tmp != 1 {
for j := 0; j < n; j++ {
btmpk[j] *= tmp
}
}
}
return
}
// Lower triangular, transposed: row k of b contributes to the rows i < k
// before being scaled.
for k := 0; k < m; k++ {
btmpk := b[k*ldb : k*ldb+n]
for i, va := range a[k*lda : k*lda+k] {
btmp := b[i*ldb : i*ldb+n]
tmp := alpha * va
if tmp != 0 {
f32.AxpyUnitaryTo(btmp, tmp, btmpk, btmp)
}
}
tmp := alpha
if nonUnit {
tmp *= a[k*lda+k]
}
if tmp != 1 {
for j := 0; j < n; j++ {
btmpk[j] *= tmp
}
}
}
return
}
// Cases where a is on the right
// Each row of b is updated independently and in place.
if tA == blas.NoTrans {
if ul == blas.Upper {
for i := 0; i < m; i++ {
btmp := b[i*ldb : i*ldb+n]
// k runs downwards: entry k only modifies entries j > k, so btmp[k] is
// still unmodified when it is read.
for k := n - 1; k >= 0; k-- {
tmp := alpha * btmp[k]
if tmp != 0 {
btmp[k] = tmp
if nonUnit {
btmp[k] *= a[k*lda+k]
}
for ja, v := range a[k*lda+k+1 : k*lda+n] {
j := ja + k + 1
btmp[j] += tmp * v
}
}
}
}
return
}
for i := 0; i < m; i++ {
btmp := b[i*ldb : i*ldb+n]
// k runs upwards: entry k only modifies entries j < k.
for k := 0; k < n; k++ {
tmp := alpha * btmp[k]
if tmp != 0 {
btmp[k] = tmp
if nonUnit {
btmp[k] *= a[k*lda+k]
}
f32.AxpyUnitaryTo(btmp, tmp, a[k*lda:k*lda+k], btmp)
}
}
}
return
}
// Cases where a is transposed.
if ul == blas.Upper {
for i := 0; i < m; i++ {
btmp := b[i*ldb : i*ldb+n]
// Entry j is a dot product with entries j..n-1, which have not yet been
// overwritten when j runs upwards.
for j, vb := range btmp {
tmp := vb
if nonUnit {
tmp *= a[j*lda+j]
}
tmp += f32.DotUnitary(a[j*lda+j+1:j*lda+n], btmp[j+1:n])
btmp[j] = alpha * tmp
}
}
return
}
for i := 0; i < m; i++ {
btmp := b[i*ldb : i*ldb+n]
// Entry j is a dot product with entries 0..j, so j runs downwards.
for j := n - 1; j >= 0; j-- {
tmp := btmp[j]
if nonUnit {
tmp *= a[j*lda+j]
}
tmp += f32.DotUnitary(a[j*lda:j*lda+j], btmp[:j])
btmp[j] = alpha * tmp
}
}
}

72
blas/native/native.go Normal file
View File

@@ -0,0 +1,72 @@
// Copyright ©2015 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:generate ./single_precision.bash
package native
// Implementation is the stateless receiver type on which the BLAS routines of
// this package are defined as methods.
type Implementation struct{}
// The following are panic strings used during parameter checks.
// Note that some names share a message: negativeN and nLT0, badLenX and badX,
// and badLenY and badY.
const (
negativeN = "blas: n < 0"
zeroIncX = "blas: zero x index increment"
zeroIncY = "blas: zero y index increment"
badLenX = "blas: x index out of range"
badLenY = "blas: y index out of range"
mLT0 = "blas: m < 0"
nLT0 = "blas: n < 0"
kLT0 = "blas: k < 0"
kLLT0 = "blas: kL < 0"
kULT0 = "blas: kU < 0"
badUplo = "blas: illegal triangle"
badTranspose = "blas: illegal transpose"
badDiag = "blas: illegal diagonal"
badSide = "blas: illegal side"
badLdA = "blas: index of a out of range"
badLdB = "blas: index of b out of range"
badLdC = "blas: index of c out of range"
badX = "blas: x index out of range"
badY = "blas: y index out of range"
)
// [SD]gemm behavior constants. These are kept here to keep them out of the
// way during single precision code generation.
const (
blockSize = 64 // b x b matrix
minParBlock = 4 // minimum number of blocks needed to go parallel
buffMul = 4 // how big is the buffer relative to the number of workers
)
// [SD]gemm debugging constant.
const debug = false
// subMul is a common type shared by [SD]gemm. It identifies one sub-block of
// the output matrix for a worker to compute.
type subMul struct {
i, j int // row and column offset of the block's first element (sgemmParallel steps these by blockSize)
}
// max returns the larger of a and b.
func max(a, b int) int {
	larger := b
	if a > larger {
		larger = a
	}
	return larger
}
// min returns the smaller of a and b.
func min(a, b int) int {
	smaller := a
	if a > b {
		smaller = b
	}
	return smaller
}
// blocks returns the number of divisions of the dimension length with the
// given block size, counting a trailing partial block as one division.
func blocks(dim, bsize int) int {
	// Adding bsize-1 before the integer division rounds the quotient up.
	padded := dim + bsize - 1
	return padded / bsize
}

View File

@@ -0,0 +1,181 @@
// Copyright ©2014 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package native
import (
"math/rand"
"testing"
"github.com/gonum/blas"
)
// TestDgemmParallel checks that dgemmParallel and dgemmSerial agree for a
// range of problem sizes chosen around the blockSize and minParBlock
// thresholds. Every size is run with all four combinations of transposed and
// untransposed a and b.
func TestDgemmParallel(t *testing.T) {
	// The transposition pairs are applied to every entry of the size table,
	// in this order.
	transposes := [][2]blas.Transpose{
		{blas.NoTrans, blas.NoTrans},
		{blas.Trans, blas.NoTrans},
		{blas.NoTrans, blas.Trans},
		{blas.Trans, blas.Trans},
	}
	for i, test := range []struct {
		m, n, k int
		alpha   float64
	}{
		{m: 3, n: 4, k: 2, alpha: 2.5},
		{m: blockSize*2 + 5, n: 3, k: 2, alpha: 2.5},
		{m: 3, n: blockSize * 2, k: 2, alpha: 2.5},
		{m: 2, n: 3, k: blockSize*3 - 2, alpha: 2.5},
		{m: blockSize * minParBlock, n: 3, k: 2, alpha: 2.5},
		{m: 3, n: blockSize * minParBlock, k: 2, alpha: 2.5},
		{m: 2, n: 3, k: blockSize * minParBlock, alpha: 2.5},
		{m: blockSize*minParBlock + 1, n: blockSize * minParBlock, k: 3, alpha: 2.5},
		{m: 3, n: blockSize*minParBlock + 2, k: blockSize * 3, alpha: 2.5},
		{m: blockSize * minParBlock, n: 3, k: blockSize * minParBlock, alpha: 2.5},
		{m: blockSize * minParBlock, n: blockSize * minParBlock, k: blockSize * 3, alpha: 2.5},
		{m: blockSize + blockSize/2, n: blockSize + blockSize/2, k: blockSize + blockSize/2, alpha: 2.5},
	} {
		for _, tr := range transposes {
			testMatchParallelSerial(t, i, tr[0], tr[1], test.m, test.n, test.k, test.alpha)
		}
	}
}
// testMatchParallelSerial fills random matrices a, b and c of the shapes
// implied by tA, tB, m, n and k, runs dgemmSerial and dgemmParallel on
// identical inputs, and reports an error if either input was modified by the
// parallel call or if the two results differ by more than 1e-12. i is the
// case index used in the error messages.
func testMatchParallelSerial(t *testing.T, i int, tA, tB blas.Transpose, m, n, k int, alpha float64) {
	// Storage shape of a is m×k, or k×m when it is to be transposed.
	rowA, colA := k, m
	if tA == blas.NoTrans {
		rowA, colA = m, k
	}
	// Storage shape of b is k×n, or n×k when it is to be transposed.
	rowB, colB := n, k
	if tB == blas.NoTrans {
		rowB, colB = k, n
	}

	// All matrices are stored contiguously: stride equals column count.
	lda, ldb, ldc := colA, colB, n
	a := randmat(rowA, colA, lda)
	b := randmat(rowB, colB, ldb)
	c := randmat(m, n, ldc)

	aOrig := a.clone()
	bOrig := b.clone()
	want := c.clone()

	transA := tA == blas.Trans
	transB := tB == blas.Trans
	dgemmSerial(transA, transB, m, n, k, a.data, lda, b.data, ldb, want.data, ldc, alpha)
	dgemmParallel(transA, transB, m, n, k, a.data, lda, b.data, ldb, c.data, ldc, alpha)

	if !a.equal(aOrig) {
		t.Errorf("Case %v: a changed during call to dgemmParallel", i)
	}
	if !b.equal(bOrig) {
		t.Errorf("Case %v: b changed during call to dgemmParallel", i)
	}
	if !c.equalWithinAbs(want, 1e-12) {
		t.Errorf("Case %v: answer not equal parallel and serial", i)
	}
}
// randmat returns an r×c general64 with the given stride whose backing slice,
// of length r*stride+c, is filled with values from rand.Float64.
func randmat(r, c, stride int) general64 {
	g := general64{
		data:   make([]float64, r*stride+c),
		rows:   r,
		cols:   c,
		stride: stride,
	}
	for idx := range g.data {
		g.data[idx] = rand.Float64()
	}
	return g
}

280
blas/native/sgemm.go Normal file
View File

@@ -0,0 +1,280 @@
// Code generated by "go generate github.com/gonum/blas/native"; DO NOT EDIT.
// Copyright ©2014 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package native
import (
"runtime"
"sync"
"github.com/gonum/blas"
"github.com/gonum/internal/asm/f32"
)
// Sgemm computes
//  C = beta * C + alpha * A * B,
// where A, B, and C are dense matrices, and alpha and beta are scalars.
// tA and tB specify whether A or B are transposed; blas.ConjTrans is treated
// the same as blas.Trans.
//
// Sgemm panics with badTranspose if tA or tB is not a recognised value, and
// checkMatrix32 panics if a, b or c is inconsistent with its dimensions and
// stride.
//
// Float32 implementations are autogenerated and not directly tested.
func (Implementation) Sgemm(tA, tB blas.Transpose, m, n, k int, alpha float32, a []float32, lda int, b []float32, ldb int, beta float32, c []float32, ldc int) {
	aTrans := tA == blas.Trans || tA == blas.ConjTrans
	if !aTrans && tA != blas.NoTrans {
		panic(badTranspose)
	}
	bTrans := tB == blas.Trans || tB == blas.ConjTrans
	if !bTrans && tB != blas.NoTrans {
		panic(badTranspose)
	}

	// a is stored as m×k, or as k×m when it is transposed.
	aRows, aCols := m, k
	if aTrans {
		aRows, aCols = k, m
	}
	checkMatrix32(aRows, aCols, a, lda)

	// b is stored as k×n, or as n×k when it is transposed.
	bRows, bCols := k, n
	if bTrans {
		bRows, bCols = n, k
	}
	checkMatrix32(bRows, bCols, b, ldb)

	checkMatrix32(m, n, c, ldc)

	// Apply beta to c up front; the multiply below only accumulates into c.
	switch beta {
	case 1:
		// c is used as it is.
	case 0:
		// Assign rather than multiply so that existing NaN or Inf entries
		// of c do not survive.
		for i := 0; i < m; i++ {
			row := c[i*ldc : i*ldc+n]
			for j := range row {
				row[j] = 0
			}
		}
	default:
		for i := 0; i < m; i++ {
			row := c[i*ldc : i*ldc+n]
			for j := range row {
				row[j] *= beta
			}
		}
	}

	sgemmParallel(aTrans, bTrans, m, n, k, a, lda, b, ldb, c, ldc, alpha)
}
// sgemmParallel adds alpha * A * B (with A and/or B transposed as requested by
// aTrans and bTrans) to c. Small problems are handed to sgemmSerial; otherwise
// c is split into blocks that are computed concurrently by worker goroutines.
func sgemmParallel(aTrans, bTrans bool, m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) {
// sgemmParallel computes a parallel matrix multiplication by partitioning
// a and b into sub-blocks, and updating c with the multiplication of the sub-block
// In all cases,
// A = [ A_11 A_12 ... A_1j
// A_21 A_22 ... A_2j
// ...
// A_i1 A_i2 ... A_ij]
//
// and same for B. All of the submatrix sizes are blockSize×blockSize except
// at the edges.
//
// In all cases, there is one dimension for each matrix along which
// C must be updated sequentially.
// Cij = \sum_k Aik Bkj, (A * B)
// Cij = \sum_k Aki Bkj, (A^T * B)
// Cij = \sum_k Aik Bjk, (A * B^T)
// Cij = \sum_k Aki Bjk, (A^T * B^T)
//
// This code computes one {i, j} block sequentially along the k dimension,
// and computes all of the {i, j} blocks concurrently. This
// partitioning allows Cij to be updated in-place without race-conditions.
// Instead of launching a goroutine for each possible concurrent computation,
// a number of worker goroutines are created and channels are used to pass
// available and completed cases.
//
// http://alexkr.com/docs/matrixmult.pdf is a good reference on matrix-matrix
// multiplies, though this code does not copy matrices to attempt to eliminate
// cache misses.
// The inner dimension is saved under a separate name because the workers
// below reuse k as their loop variable.
maxKLen := k
parBlocks := blocks(m, blockSize) * blocks(n, blockSize)
if parBlocks < minParBlock {
// The matrix multiplication is small in the dimensions where it can be
// computed concurrently. Just do it in serial.
sgemmSerial(aTrans, bTrans, m, n, k, a, lda, b, ldb, c, ldc, alpha)
return
}
// Never start more workers than there are blocks to compute.
nWorkers := runtime.GOMAXPROCS(0)
if parBlocks < nWorkers {
nWorkers = parBlocks
}
// There is a tradeoff between the workers having to wait for work
// and a large buffer making operations slow.
buf := buffMul * nWorkers
if buf > parBlocks {
buf = parBlocks
}
sendChan := make(chan subMul, buf)
// Launch workers. A worker receives an {i, j} submatrix of c, and computes
// A_ik B_kj (or the transposed version) storing the result in c_ij. When the
// channel is finally closed, it signals to the waitgroup that it has finished
// computing.
var wg sync.WaitGroup
for i := 0; i < nWorkers; i++ {
wg.Add(1)
go func() {
defer wg.Done()
// Make local copies of otherwise global variables to reduce shared memory.
// This has a noticeable effect on benchmarks in some cases.
alpha := alpha
aTrans := aTrans
bTrans := bTrans
m := m
n := n
for sub := range sendChan {
i := sub.i
j := sub.j
// Blocks on the bottom and right edges may be smaller than blockSize.
leni := blockSize
if i+leni > m {
leni = m - i
}
lenj := blockSize
if j+lenj > n {
lenj = n - j
}
cSub := sliceView32(c, ldc, i, j, leni, lenj)
// Compute A_ik B_kj for all k
for k := 0; k < maxKLen; k += blockSize {
lenk := blockSize
if k+lenk > maxKLen {
lenk = maxKLen - k
}
// The views keep the parent stride, so a transposed operand is
// addressed with its row and column offsets swapped.
var aSub, bSub []float32
if aTrans {
aSub = sliceView32(a, lda, k, i, lenk, leni)
} else {
aSub = sliceView32(a, lda, i, k, leni, lenk)
}
if bTrans {
bSub = sliceView32(b, ldb, j, k, lenj, lenk)
} else {
bSub = sliceView32(b, ldb, k, j, lenk, lenj)
}
sgemmSerial(aTrans, bTrans, leni, lenj, lenk, aSub, lda, bSub, ldb, cSub, ldc, alpha)
}
}
}()
}
// Send out all of the {i, j} subblocks for computation.
for i := 0; i < m; i += blockSize {
for j := 0; j < n; j += blockSize {
sendChan <- subMul{
i: i,
j: j,
}
}
}
// Closing the channel ends the workers' range loops; wait for them to finish.
close(sendChan)
wg.Wait()
}
// sgemmSerial is serial matrix multiply
func sgemmSerial(aTrans, bTrans bool, m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) {
switch {
case !aTrans && !bTrans:
sgemmSerialNotNot(m, n, k, a, lda, b, ldb, c, ldc, alpha)
return
case aTrans && !bTrans:
sgemmSerialTransNot(m, n, k, a, lda, b, ldb, c, ldc, alpha)
return
case !aTrans && bTrans:
sgemmSerialNotTrans(m, n, k, a, lda, b, ldb, c, ldc, alpha)
return
case aTrans && bTrans:
sgemmSerialTransTrans(m, n, k, a, lda, b, ldb, c, ldc, alpha)
return
default:
panic("unreachable")
}
}
// sgemmSerialNotNot is sgemmSerial for the case where neither a nor b is
// transposed. It adds alpha * a * b to c, one row of c at a time.
func sgemmSerialNotNot(m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) {
	// Rows are sliced out once rather than indexing with [i*stride+j]
	// throughout; the original authors measured this as approximately 5 times
	// faster as of go 1.3.
	for row := 0; row < m; row++ {
		cRow := c[row*ldc : row*ldc+n]
		aRow := a[row*lda : row*lda+k]
		for l, av := range aRow {
			scale := alpha * av
			if scale == 0 {
				continue
			}
			f32.AxpyUnitaryTo(cRow, scale, b[l*ldb:l*ldb+n], cRow)
		}
	}
}
// sgemmSerialTransNot is sgemmSerial for the case where a is transposed and b
// is not. It adds alpha * a^T * b to c.
func sgemmSerialTransNot(m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) {
	// Rows are sliced out once rather than indexing with [i*stride+j]
	// throughout; the original authors measured this as approximately 5 times
	// faster as of go 1.3.
	for l := 0; l < k; l++ {
		bRow := b[l*ldb : l*ldb+n]
		aRow := a[l*lda : l*lda+m]
		// Row l of a holds column l of a^T: entry i scales the contribution
		// of row l of b to row i of c.
		for i, av := range aRow {
			scale := alpha * av
			if scale == 0 {
				continue
			}
			cRow := c[i*ldc : i*ldc+n]
			f32.AxpyUnitaryTo(cRow, scale, bRow, cRow)
		}
	}
}
// sgemmSerialNotTrans is sgemmSerial for the case where a is not transposed
// and b is. It adds alpha * a * b^T to c; each entry is the dot product of a
// row of a with a row of b.
func sgemmSerialNotTrans(m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) {
	// Rows are sliced out once rather than indexing with [i*stride+j]
	// throughout; the original authors measured this as approximately 5 times
	// faster as of go 1.3.
	for row := 0; row < m; row++ {
		aRow := a[row*lda : row*lda+k]
		cRow := c[row*ldc : row*ldc+n]
		for col := range cRow {
			cRow[col] += alpha * f32.DotUnitary(aRow, b[col*ldb:col*ldb+k])
		}
	}
}
// sgemmSerialTransTrans is sgemmSerial for the case where both a and b are
// transposed. It adds alpha * a^T * b^T to c.
func sgemmSerialTransTrans(m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) {
	// Rows are sliced out once rather than indexing with [i*stride+j]
	// throughout; the original authors measured this as approximately 5 times
	// faster as of go 1.3.
	for l := 0; l < k; l++ {
		aRow := a[l*lda : l*lda+m]
		// Column l of b is read with a strided axpy (increment ldb).
		bCol := b[l:]
		for i, av := range aRow {
			scale := alpha * av
			if scale == 0 {
				continue
			}
			cRow := c[i*ldc : i*ldc+n]
			f32.AxpyInc(scale, bCol, cRow, uintptr(n), uintptr(ldb), 1, 0, 0)
		}
	}
}
// sliceView32 returns the sub-slice of a that spans the r×c block whose first
// element is at row i, column j of a matrix stored with stride lda. The view
// shares memory with a and is to be indexed with the same stride.
func sliceView32(a []float32, lda, i, j, r, c int) []float32 {
	first := i*lda + j
	// One past the last element of the block's final row.
	end := (i+r-1)*lda + j + c
	return a[first:end]
}
// checkMatrix32 panics if a cannot hold an m×n matrix stored with stride lda.
// The checks are made in order: negative row count, negative column count,
// stride shorter than a row, and a slice too short for the last element.
func checkMatrix32(m, n int, a []float32, lda int) {
	switch {
	case m < 0:
		panic("blas: rows < 0")
	case n < 0:
		panic("blas: cols < 0")
	case lda < n:
		panic("blas: illegal stride")
	case len(a) < (m-1)*lda+n:
		panic("blas: insufficient matrix slice length")
	}
}

156
blas/native/single_precision.bash Executable file
View File

@@ -0,0 +1,156 @@
#!/usr/bin/env bash
# Copyright ©2015 The gonum Authors. All rights reserved.
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file.
# WARNING is comment text that the sed commands below insert immediately
# before each rewritten "func (Implementation) ..." line. The trailing
# backslashes make the embedded newlines literal in the sed replacement.
WARNING='//\
// Float32 implementations are autogenerated and not directly tested.\
'
# Level1 routines.
# level1single.go is derived from level1double.go: gofmt -r rewrites the
# types and f64 calls to their float32/f32 counterparts, and sed renames the
# D*/Id* methods and their doc comments to S*/Is*, inserts $WARNING before
# each method, and swaps the f64 and math imports.
echo Generating level1single.go
echo -e '// Code generated by "go generate github.com/gonum/blas/native"; DO NOT EDIT.\n' > level1single.go
cat level1double.go \
| gofmt -r 'blas.Float64Level1 -> blas.Float32Level1' \
\
| gofmt -r 'float64 -> float32' \
| gofmt -r 'blas.DrotmParams -> blas.SrotmParams' \
\
| gofmt -r 'f64.AxpyInc -> f32.AxpyInc' \
| gofmt -r 'f64.AxpyIncTo -> f32.AxpyIncTo' \
| gofmt -r 'f64.AxpyUnitaryTo -> f32.AxpyUnitaryTo' \
| gofmt -r 'f64.DotUnitary -> f32.DotUnitary' \
| gofmt -r 'f64.ScalUnitary -> f32.ScalUnitary' \
\
| sed -e "s_^\(func (Implementation) \)D\(.*\)\$_$WARNING\1S\2_" \
-e 's_^// D_// S_' \
-e "s_^\(func (Implementation) \)Id\(.*\)\$_$WARNING\1Is\2_" \
-e 's_^// Id_// Is_' \
-e 's_"github.com/gonum/internal/asm/f64"_"github.com/gonum/internal/asm/f32"_' \
-e 's_"math"_math "github.com/gonum/blas/native/internal/math32"_' \
>> level1single.go
# level1single_sdot.go: the float32 dot product, derived from
# level1double_ddot.go by rewriting float64 to float32 and the f64 dot
# kernels to their f32 counterparts.
echo Generating level1single_sdot.go
echo -e '// Code generated by "go generate github.com/gonum/blas/native"; DO NOT EDIT.\n' > level1single_sdot.go
cat level1double_ddot.go \
| gofmt -r 'float64 -> float32' \
\
| gofmt -r 'f64.DotInc -> f32.DotInc' \
| gofmt -r 'f64.DotUnitary -> f32.DotUnitary' \
\
| sed -e "s_^\(func (Implementation) \)D\(.*\)\$_$WARNING\1S\2_" \
-e 's_^// D_// S_' \
-e 's_"github.com/gonum/internal/asm/f64"_"github.com/gonum/internal/asm/f32"_' \
>> level1single_sdot.go
# level1single_dsdot.go: derived from level1double_ddot.go. Only the slice
# types are rewritten to []float32; other float64 uses (such as the return
# type) are left alone, and the kernels become f32.DdotInc/DdotUnitary. The
# method is renamed with a "Ds" prefix.
echo Generating level1single_dsdot.go
echo -e '// Code generated by "go generate github.com/gonum/blas/native"; DO NOT EDIT.\n' > level1single_dsdot.go
cat level1double_ddot.go \
| gofmt -r '[]float64 -> []float32' \
\
| gofmt -r 'f64.DotInc -> f32.DdotInc' \
| gofmt -r 'f64.DotUnitary -> f32.DdotUnitary' \
\
| sed -e "s_^\(func (Implementation) \)D\(.*\)\$_$WARNING\1Ds\2_" \
-e 's_^// D_// Ds_' \
-e 's_"github.com/gonum/internal/asm/f64"_"github.com/gonum/internal/asm/f32"_' \
>> level1single_dsdot.go
# level1single_sdsdot.go: derived from level1double_ddot.go. The dot kernels
# are replaced by "alpha + float32(f32.Ddot...)" expressions, the method is
# renamed with an "Sds" prefix, and sed adds an "alpha float32" parameter by
# rewriting every occurrence of "n int".
echo Generating level1single_sdsdot.go
echo -e '// Code generated by "go generate github.com/gonum/blas/native"; DO NOT EDIT.\n' > level1single_sdsdot.go
cat level1double_ddot.go \
| gofmt -r 'float64 -> float32' \
\
| gofmt -r 'f64.DotInc(x, y, f(n), f(incX), f(incY), f(ix), f(iy)) -> alpha + float32(f32.DdotInc(x, y, f(n), f(incX), f(incY), f(ix), f(iy)))' \
| gofmt -r 'f64.DotUnitary(a, b) -> alpha + float32(f32.DdotUnitary(a, b))' \
\
| sed -e "s_^\(func (Implementation) \)D\(.*\)\$_$WARNING\1Sds\2_" \
-e 's_^// D\(.*\)$_// Sds\1 plus a constant_' \
-e 's_\\sum_alpha + \\sum_' \
-e 's/n int/n int, alpha float32/' \
-e 's_"github.com/gonum/internal/asm/f64"_"github.com/gonum/internal/asm/f32"_' \
>> level1single_sdsdot.go
# Level2 routines.
# level2single.go is derived from level2double.go: types and f64 calls are
# rewritten to float32/f32, calls to Dscal become Sscal, and the D* methods
# and their doc comments are renamed to S* with $WARNING inserted.
echo Generating level2single.go
echo -e '// Code generated by "go generate github.com/gonum/blas/native"; DO NOT EDIT.\n' > level2single.go
cat level2double.go \
| gofmt -r 'blas.Float64Level2 -> blas.Float32Level2' \
\
| gofmt -r 'float64 -> float32' \
\
| gofmt -r 'Dscal -> Sscal' \
\
| gofmt -r 'f64.AxpyInc -> f32.AxpyInc' \
| gofmt -r 'f64.AxpyIncTo -> f32.AxpyIncTo' \
| gofmt -r 'f64.AxpyUnitary -> f32.AxpyUnitary' \
| gofmt -r 'f64.AxpyUnitaryTo -> f32.AxpyUnitaryTo' \
| gofmt -r 'f64.DotInc -> f32.DotInc' \
| gofmt -r 'f64.DotUnitary -> f32.DotUnitary' \
\
| sed -e "s_^\(func (Implementation) \)D\(.*\)\$_$WARNING\1S\2_" \
-e 's_^// D_// S_' \
-e 's_"github.com/gonum/internal/asm/f64"_"github.com/gonum/internal/asm/f32"_' \
>> level2single.go
# Level3 routines.
# level3single.go is derived from level3double.go: types and f64 calls are
# rewritten to float32/f32, and the D* methods and their doc comments are
# renamed to S* with $WARNING inserted.
echo Generating level3single.go
echo -e '// Code generated by "go generate github.com/gonum/blas/native"; DO NOT EDIT.\n' > level3single.go
cat level3double.go \
| gofmt -r 'blas.Float64Level3 -> blas.Float32Level3' \
\
| gofmt -r 'float64 -> float32' \
\
| gofmt -r 'f64.AxpyUnitaryTo -> f32.AxpyUnitaryTo' \
| gofmt -r 'f64.DotUnitary -> f32.DotUnitary' \
\
| sed -e "s_^\(func (Implementation) \)D\(.*\)\$_$WARNING\1S\2_" \
-e 's_^// D_// S_' \
-e 's_"github.com/gonum/internal/asm/f64"_"github.com/gonum/internal/asm/f32"_' \
>> level3single.go
# general_single.go is derived from general_double.go: float64 becomes
# float32 and general64/newGeneral64 become general32/newGeneral32. The
# print method's receiver type is additionally rewritten with sed.
echo Generating general_single.go
echo -e '// Code generated by "go generate github.com/gonum/blas/native"; DO NOT EDIT.\n' > general_single.go
cat general_double.go \
| gofmt -r 'float64 -> float32' \
\
| gofmt -r 'general64 -> general32' \
| gofmt -r 'newGeneral64 -> newGeneral32' \
\
| sed -e 's/(g general64) print()/(g general32) print()/' \
-e 's_"math"_math "github.com/gonum/blas/native/internal/math32"_' \
-e 's_"github.com/gonum/internal/asm/f64"_"github.com/gonum/internal/asm/f32"_' \
>> general_single.go
# sgemm.go is derived from dgemm.go: types, helper names and f64 calls are
# rewritten to their single precision counterparts. The sed comment rules
# are anchored with ^, so only comments that start at the beginning of a
# line are renamed; mentions of the double precision names elsewhere in
# comments are left as they are.
echo Generating sgemm.go
echo -e '// Code generated by "go generate github.com/gonum/blas/native"; DO NOT EDIT.\n' > sgemm.go
cat dgemm.go \
| gofmt -r 'float64 -> float32' \
| gofmt -r 'general64 -> general32' \
| gofmt -r 'sliceView64 -> sliceView32' \
| gofmt -r 'checkMatrix64 -> checkMatrix32' \
\
| gofmt -r 'dgemmParallel -> sgemmParallel' \
| gofmt -r 'computeNumBlocks64 -> computeNumBlocks32' \
| gofmt -r 'dgemmSerial -> sgemmSerial' \
| gofmt -r 'dgemmSerialNotNot -> sgemmSerialNotNot' \
| gofmt -r 'dgemmSerialTransNot -> sgemmSerialTransNot' \
| gofmt -r 'dgemmSerialNotTrans -> sgemmSerialNotTrans' \
| gofmt -r 'dgemmSerialTransTrans -> sgemmSerialTransTrans' \
\
| gofmt -r 'f64.AxpyInc -> f32.AxpyInc' \
| gofmt -r 'f64.AxpyIncTo -> f32.AxpyIncTo' \
| gofmt -r 'f64.AxpyUnitaryTo -> f32.AxpyUnitaryTo' \
| gofmt -r 'f64.DotUnitary -> f32.DotUnitary' \
\
| sed -e "s_^\(func (Implementation) \)D\(.*\)\$_$WARNING\1S\2_" \
-e 's_^// D_// S_' \
-e 's_^// d_// s_' \
-e 's_"github.com/gonum/internal/asm/f64"_"github.com/gonum/internal/asm/f32"_' \
>> sgemm.go

View File

@@ -0,0 +1,292 @@
// Copyright 2014 The Gonum Authors. All rights reserved.
// Use of this code is governed by a BSD-style
// license that can be found in the LICENSE file
// Script for automatic code generation of the benchmark routines
package main
import (
"fmt"
"os"
"os/exec"
"path"
"path/filepath"
"strconv"
)
// gopath holds the value of the GOPATH environment variable; it is set in init.
var gopath string
// copyrightnotice is the license header written into each generated file.
var copyrightnotice = []byte(`// Copyright 2014 The Gonum Authors. All rights reserved.
// Use of this code is governed by a BSD-style
// license that can be found in the LICENSE file`)
// autogen is the "Code generated ... DO NOT EDIT." first line of each
// generated file.
var autogen = []byte("// Code generated by \"go run github.com/gonum/blas/testblas/benchautogen/autogen_bench_level1double.go\"; DO NOT EDIT.\n")
// imports is the import declaration written into each generated file.
var imports = []byte(`import(
"math/rand"
"testing"
"github.com/gonum/blas"
)`)
// randomSliceFunction is the source of the randomSlice helper that the
// generated benchmarks call to build their inputs. It is written to the
// generated file verbatim.
var randomSliceFunction = []byte(`func randomSlice(l, idx int) ([]float64) {
if idx < 0{
idx = -idx
}
s := make([]float64, l * idx)
for i := range s {
s[i] = rand.Float64()
}
return s
}`)
// Increment values that are written as constants of the same names into the
// generated benchmark file. In this generator the benchmarks themselves only
// refer to posInc1 and negInc1.
const (
posInc1 = 5
posInc2 = 3
negInc1 = -3
negInc2 = -4
)
// level1Sizes lists the vector lengths that each level 1 benchmark is
// generated for.
var level1Sizes = []struct {
lower string // lower-case name of the size; not read by the generator code in this file
upper string // name of the constant emitted into the generated file
camel string // name fragment used in the generated benchmark function names
size int // value of the emitted constant
}{
{
lower: "small",
upper: "SMALL_SLICE",
camel: "Small",
size: 10,
},
{
lower: "medium",
upper: "MEDIUM_SLICE",
camel: "Medium",
size: 1000,
},
{
lower: "large",
upper: "LARGE_SLICE",
camel: "Large",
size: 100000,
},
{
lower: "huge",
upper: "HUGE_SLICE",
camel: "Huge",
size: 10000000,
},
}
// level1functionStruct describes one level 1 BLAS routine for which
// benchmarks are generated.
type level1functionStruct struct {
camel string // method name, as called on impl in the generated code
sig string // parameter list of the generated benchmark helper
call string // argument list passed on to the routine
extraSetup string // extra statements emitted before the call, e.g. to declare alpha
oneInput bool // true if the routine takes only x, false if it takes x and y
extraName string // if have a couple different cases for the same function
}
// level1Functions lists the level 1 routines for which benchmarks are
// generated. Drotm appears three times, once for each of the off-diagonal,
// diagonal and rescaling forms of its parameters; extraName keeps the
// generated benchmark names distinct.
var level1Functions = []level1functionStruct{
	{
		camel:    "Ddot",
		sig:      "n int, x []float64, incX int, y []float64, incY int",
		call:     "n, x, incX, y, incY",
		oneInput: false,
	},
	{
		camel:    "Dnrm2",
		sig:      "n int, x []float64, incX int",
		call:     "n, x, incX",
		oneInput: true,
	},
	{
		camel:    "Dasum",
		sig:      "n int, x []float64, incX int",
		call:     "n, x, incX",
		oneInput: true,
	},
	{
		camel:    "Idamax",
		sig:      "n int, x []float64, incX int",
		call:     "n, x, incX",
		oneInput: true,
	},
	{
		camel:    "Dswap",
		sig:      "n int, x []float64, incX int, y []float64, incY int",
		call:     "n, x, incX, y, incY",
		oneInput: false,
	},
	{
		camel:    "Dcopy",
		sig:      "n int, x []float64, incX int, y []float64, incY int",
		call:     "n, x, incX, y, incY",
		oneInput: false,
	},
	{
		camel:      "Daxpy",
		sig:        "n int, alpha float64, x []float64, incX int, y []float64, incY int",
		call:       "n, alpha, x, incX, y, incY",
		extraSetup: "alpha := 2.4",
		oneInput:   false,
	},
	{
		camel:      "Drot",
		sig:        "n int, x []float64, incX int, y []float64, incY int, c, s float64",
		call:       "n, x, incX, y, incY, c, s",
		extraSetup: "c := 0.89725836967\ns:= 0.44150585279",
		oneInput:   false,
	},
	{
		camel:      "Drotm",
		sig:        "n int, x []float64, incX int, y []float64, incY int, p blas.DrotmParams",
		call:       "n, x, incX, y, incY, p",
		extraSetup: "p := blas.DrotmParams{Flag: blas.OffDiagonal, H: [4]float64{0, -0.625, 0.9375,0}}",
		oneInput:   false,
		extraName:  "OffDia",
	},
	{
		camel: "Drotm",
		sig:   "n int, x []float64, incX int, y []float64, incY int, p blas.DrotmParams",
		call:  "n, x, incX, y, incY, p",
		// The flag must match the populated entries of H (here the diagonal
		// ones); with blas.OffDiagonal this case would benchmark the same
		// branch as the "OffDia" case above.
		extraSetup: "p := blas.DrotmParams{Flag: blas.Diagonal, H: [4]float64{5.0 / 12, 0, 0, 0.625}}",
		oneInput:   false,
		extraName:  "Dia",
	},
	{
		camel: "Drotm",
		sig:   "n int, x []float64, incX int, y []float64, incY int, p blas.DrotmParams",
		call:  "n, x, incX, y, incY, p",
		// All four entries of H are populated, which is the rescaling form.
		extraSetup: "p := blas.DrotmParams{Flag: blas.Rescaling, H: [4]float64{4096, -3584, 1792, 4096}}",
		oneInput:   false,
		extraName:  "Resc",
	},
	{
		camel:      "Dscal",
		sig:        "n int, alpha float64, x []float64, incX int",
		call:       "n, alpha, x, incX",
		extraSetup: "alpha := 2.4",
		oneInput:   true,
	},
}
// init records GOPATH in gopath and panics if the environment variable is
// empty or unset, since main builds its output paths from it.
func init() {
	if gopath = os.Getenv("GOPATH"); gopath == "" {
		panic("gopath not set")
	}
}
// main generates the level 1 benchmark file for the native and cgo packages
// in turn and runs "go fmt" on each package afterwards. On the first error it
// prints the error and exits with status 1.
func main() {
	fail := func(err error) {
		fmt.Println(err)
		os.Exit(1)
	}
	blasPath := filepath.Join(gopath, "src", "github.com", "gonum", "blas")
	for _, name := range []string{"native", "cgo"} {
		if err := level1(filepath.Join(blasPath, name), name); err != nil {
			fail(err)
		}
		// Import paths always use forward slashes, hence path rather than
		// filepath.
		importPath := path.Join("github.com", "gonum", "blas", name)
		if err := exec.Command("go", "fmt", importPath).Run(); err != nil {
			fail(err)
		}
	}
}
// printHeader writes the preamble of a generated file to f: the generated
// code marker, the copyright notice, the package clause for name and the
// import declaration, each followed by a blank line. It returns the first
// write error encountered.
func printHeader(f *os.File, name string) error {
	sections := [][]byte{
		autogen,
		copyrightnotice,
		[]byte("package " + name),
		imports,
	}
	for _, section := range sections {
		if _, err := f.Write(section); err != nil {
			return err
		}
		if _, err := f.WriteString("\n\n"); err != nil {
			return err
		}
	}
	return nil
}
// Generate the benchmark scripts for level1
func level1(benchPath string, pkgname string) error {
// Generate level 1 benchmarks
level1Filepath := filepath.Join(benchPath, "level1doubleBench_auto_test.go")
f, err := os.Create(level1Filepath)
if err != nil {
fmt.Println(err)
os.Exit(1)
}
defer f.Close()
printHeader(f, pkgname)
// Print all of the constants
f.WriteString("const (\n")
f.WriteString("\tposInc1 = " + strconv.Itoa(posInc1) + "\n")
f.WriteString("\tposInc2 = " + strconv.Itoa(posInc2) + "\n")
f.WriteString("\tnegInc1 = " + strconv.Itoa(negInc1) + "\n")
f.WriteString("\tnegInc2 = " + strconv.Itoa(negInc2) + "\n")
for _, con := range level1Sizes {
f.WriteString("\t" + con.upper + " = " + strconv.Itoa(con.size) + "\n")
}
f.WriteString(")\n")
f.WriteString("\n")
// Write the randomSlice function
f.Write(randomSliceFunction)
f.WriteString("\n\n")
// Start writing the benchmarks
for _, fun := range level1Functions {
writeLevel1Benchmark(fun, f)
f.WriteString("\n/* ------------------ */ \n")
}
return nil
}
// writeLevel1Benchmark writes the benchmark source for fun to f: one
// unexported helper that calls the routine b.N times, followed by one
// exported Benchmark function per entry of level1Sizes and per increment
// variant. Write errors are ignored.
func writeLevel1Benchmark(fun level1functionStruct, f *os.File) {
	emit := func(parts ...string) {
		for _, p := range parts {
			f.WriteString(p)
		}
	}
	base := fun.camel + fun.extraName

	// First, write the base benchmark function.
	emit(
		"func benchmark"+base+"(b *testing.B, ",
		fun.sig,
		") {\n",
		"b.ResetTimer()\n",
		"for i := 0; i < b.N; i++{\n",
		"\timpl."+fun.camel+"(",
		fun.call,
		")\n}\n}\n",
		"\n",
	)

	// The increment variants depend on whether the routine takes one vector
	// or two; incY is unused for single-input routines.
	type variant struct{ incX, incY, name string }
	variants := []variant{
		{"1", "1", "BothUnitary"},
		{"posInc1", "1", "IncUni"},
		{"1", "negInc1", "UniInc"},
		{"posInc1", "negInc1", "BothInc"},
	}
	if fun.oneInput {
		variants = []variant{
			{"1", "", "UnitaryInc"},
			{"posInc1", "", "PosInc"},
		}
	}

	// Write all of the benchmarks that call the base function.
	for _, sz := range level1Sizes {
		for _, v := range variants {
			emit(
				"func Benchmark"+base+sz.camel+v.name+"(b *testing.B){\n",
				"n := "+sz.upper+"\n",
				"incX := "+v.incX+"\n",
				"x := randomSlice(n, incX)\n",
			)
			if !fun.oneInput {
				emit(
					"incY := "+v.incY+"\n",
					"y := randomSlice(n, incY)\n",
				)
			}
			emit(
				fun.extraSetup+"\n",
				"benchmark"+base+"(b, "+fun.call+")\n",
				"}\n\n",
			)
		}
	}
}

View File

@@ -0,0 +1,8 @@
package testblas
// Named matrix sizes, each ten times the previous one.
// NOTE(review): presumably these are the matrix dimensions used by the
// benchmarks in this package; their uses are not visible here, so confirm.
const (
SmallMat = 10
MediumMat = 100
LargeMat = 1000
HugeMat = 10000
)

237
blas/testblas/common.go Normal file
View File

@@ -0,0 +1,237 @@
package testblas
import (
"math"
"testing"
"github.com/gonum/blas"
)
// throwPanic selects how unexpected panics are handled: if true they are
// thrown (re-panicked), and if false they are just reported as test errors.
const throwPanic = true
// dTolEqual reports whether a and b are equal to within a tolerance of 1e-14.
// Two NaNs compare equal. When the larger magnitude exceeds 1 both values are
// divided by it first, so the tolerance is relative for large values and
// absolute for small ones.
func dTolEqual(a, b float64) bool {
	switch {
	case math.IsNaN(a) && math.IsNaN(b):
		return true
	case a == b:
		return true
	}
	if scale := math.Max(math.Abs(a), math.Abs(b)); scale > 1 {
		a /= scale
		b /= scale
	}
	return math.Abs(a-b) < 1e-14
}
// dSliceTolEqual reports whether a and b have the same length and every pair
// of corresponding elements satisfies dTolEqual.
func dSliceTolEqual(a, b []float64) bool {
	if len(a) != len(b) {
		return false
	}
	for i, av := range a {
		if dTolEqual(av, b[i]) {
			continue
		}
		return false
	}
	return true
}
// dStridedSliceTolEqual reports whether the n elements of a taken with
// increment inca and the n elements of b taken with increment incb are
// pairwise equal according to dTolEqual. For a non-positive increment the
// walk starts at index -(n-1)*inc and moves towards the start of the slice.
func dStridedSliceTolEqual(n int, a []float64, inca int, b []float64, incb int) bool {
	var ia, ib int
	if inca <= 0 {
		ia = -(n - 1) * inca
	}
	if incb <= 0 {
		ib = -(n - 1) * incb
	}
	for i := 0; i < n; i, ia, ib = i+1, ia+inca, ib+incb {
		if !dTolEqual(a[ia], b[ib]) {
			return false
		}
	}
	return true
}
// dSliceEqual reports whether a and b have the same length and are exactly
// equal element by element. A NaN is never equal to anything, itself included.
func dSliceEqual(a, b []float64) bool {
	if len(a) != len(b) {
		return false
	}
	for i, av := range a {
		if av != b[i] {
			return false
		}
	}
	return true
}
// dCopyTwoTmp copies x into xTmp and y into yTmp. It panics if either pair of
// slices differs in length.
func dCopyTwoTmp(x, xTmp, y, yTmp []float64) {
	switch {
	case len(x) != len(xTmp):
		panic("x size mismatch")
	case len(y) != len(yTmp):
		panic("y size mismatch")
	}
	copy(xTmp, x)
	copy(yTmp, y)
}
// panics calls f and reports whether f panicked, recovering the panic if so.
func panics(f func()) (b bool) {
	defer func() {
		if recover() != nil {
			b = true
		}
	}()
	f()
	return
}
// testpanics reports a test error, identified by name, if f does not panic.
func testpanics(f func(), name string, t *testing.T) {
	if panics(f) {
		return
	}
	t.Errorf("%v should panic and does not", name)
}
// sliceOfSliceCopy returns a deep copy of a: the result and each of its rows
// are newly allocated, so they share no memory with a.
func sliceOfSliceCopy(a [][]float64) [][]float64 {
	dup := make([][]float64, len(a))
	for i, row := range a {
		dup[i] = append(make([]float64, 0, len(row)), row...)
	}
	return dup
}
// sliceCopy returns a newly allocated copy of a.
func sliceCopy(a []float64) []float64 {
	return append(make([]float64, 0, len(a)), a...)
}
// flatten returns the elements of a as a single slice laid out row after row.
// The row length is taken from a[0]. It returns nil if a has no rows.
func flatten(a [][]float64) []float64 {
	rows := len(a)
	if rows == 0 {
		return nil
	}
	cols := len(a[0])
	flat := make([]float64, 0, rows*cols)
	for _, row := range a {
		for j := 0; j < cols; j++ {
			flat = append(flat, row[j])
		}
	}
	return flat
}
// unflatten returns the first m*n elements of a as m newly allocated rows of
// n elements each, reading a row after row.
func unflatten(a []float64, m, n int) [][]float64 {
	out := make([][]float64, m)
	for i := range out {
		row := make([]float64, n)
		for j := range row {
			row[j] = a[i*n+j]
		}
		out[i] = row
	}
	return out
}
// flattenTriangular packs one triangle of the square matrix a into a single
// slice, row by row. For ul == blas.Upper each row contributes its elements
// from the diagonal onwards; for any other ul each row contributes its
// elements up to and including the diagonal.
func flattenTriangular(a [][]float64, ul blas.Uplo) []float64 {
	order := len(a)
	packed := make([]float64, order*(order+1)/2)
	// rest is the part of packed that has not been filled yet.
	rest := packed
	for i, row := range a {
		var seg []float64
		if ul == blas.Upper {
			seg = row[i:]
		} else {
			seg = row[:i+1]
		}
		rest = rest[copy(rest, seg):]
	}
	return packed
}
// flattenBanded converts the dense banded matrix a, with ku super-diagonals
// and kl sub-diagonals, into compact banded storage: one row of ku+kl+1
// entries per row of a, with the diagonal element in column kl. Positions
// that fall outside the matrix are set to NaN. It panics if ku or kl is
// negative.
func flattenBanded(a [][]float64, ku, kl int) []float64 {
	rows := len(a)
	cols := len(a[0])
	if ku < 0 || kl < 0 {
		panic("testblas: negative band length")
	}
	width := ku + kl + 1
	band := make([]float64, rows*width)
	for i := range band {
		band[i] = math.NaN()
	}
	// Elements of the ith row stay in the ith row, and their order within
	// the band is kept.
	for i, row := range a {
		// lo and hi bound the offsets from the diagonal that lie both inside
		// the band and inside the matrix.
		lo := -kl
		if i-kl < 0 {
			lo = -i
		}
		hi := ku
		if i+ku >= cols {
			hi = cols - i - 1
		}
		for off := lo; off <= hi; off++ {
			band[i*width+kl+off] = row[i+off]
		}
	}
	return band
}
// makeIncremented takes a slice with inc == 1 and returns a new slice holding
// the same values spaced |inc| apart, followed by extra trailing values. For
// negative inc the values of x are laid out in reverse order. The gaps and
// the tail are filled with distinct values counting up from 100. It panics
// if inc is zero.
func makeIncremented(x []float64, inc int, extra int) []float64 {
	if inc == 0 {
		panic("zero inc")
	}
	// gap is the number of filler values between two consecutive elements.
	gap := inc - 1
	if inc < 0 {
		gap = -inc - 1
	}
	// NaN is not used as filler because it makes comparison hard; a unique
	// recognisable value makes debugging easier.
	filler := 100.0
	last := len(x) - 1
	var out []float64
	for i := range x {
		src := i
		if inc < 0 {
			src = last - i
		}
		out = append(out, x[src])
		if i == last {
			continue
		}
		for g := 0; g < gap; g++ {
			out = append(out, filler)
			filler++
		}
	}
	for e := 0; e < extra; e++ {
		out = append(out, filler)
		filler++
	}
	return out
}

View File

@@ -0,0 +1,187 @@
package testblas
import (
"math"
"testing"
"github.com/gonum/blas"
"github.com/gonum/floats"
)
// TestFlattenBanded checks flattenBanded against hand-written band-storage
// layouts. Each case gives a dense matrix, its number of super-diagonals (ku)
// and sub-diagonals (kl), and the expected compact form written as one row of
// ku+kl+1 entries per dense row, with NaN in the positions flattenBanded
// leaves unset.
func TestFlattenBanded(t *testing.T) {
for i, test := range []struct {
dense [][]float64
ku int
kl int
condensed [][]float64
}{
{
dense: [][]float64{{3}},
ku: 0,
kl: 0,
condensed: [][]float64{{3}},
},
{
dense: [][]float64{
{3, 4, 0},
},
ku: 1,
kl: 0,
condensed: [][]float64{
{3, 4},
},
},
{
dense: [][]float64{
{3, 4, 0, 0, 0},
},
ku: 1,
kl: 0,
condensed: [][]float64{
{3, 4},
},
},
{
dense: [][]float64{
{3, 4, 0},
{0, 5, 8},
{0, 0, 2},
{0, 0, 0},
{0, 0, 0},
},
ku: 1,
kl: 0,
condensed: [][]float64{
{3, 4},
{5, 8},
{2, math.NaN()},
{math.NaN(), math.NaN()},
{math.NaN(), math.NaN()},
},
},
{
dense: [][]float64{
{3, 4, 6},
{0, 5, 8},
{0, 0, 2},
{0, 0, 0},
{0, 0, 0},
},
ku: 2,
kl: 0,
condensed: [][]float64{
{3, 4, 6},
{5, 8, math.NaN()},
{2, math.NaN(), math.NaN()},
{math.NaN(), math.NaN(), math.NaN()},
{math.NaN(), math.NaN(), math.NaN()},
},
},
{
dense: [][]float64{
{3, 4, 6},
{1, 5, 8},
{0, 6, 2},
{0, 0, 7},
{0, 0, 0},
},
ku: 2,
kl: 1,
condensed: [][]float64{
{math.NaN(), 3, 4, 6},
{1, 5, 8, math.NaN()},
{6, 2, math.NaN(), math.NaN()},
{7, math.NaN(), math.NaN(), math.NaN()},
{math.NaN(), math.NaN(), math.NaN(), math.NaN()},
},
},
{
dense: [][]float64{
{1, 2, 0},
{3, 4, 5},
{6, 7, 8},
{0, 9, 10},
{0, 0, 11},
},
ku: 1,
kl: 2,
condensed: [][]float64{
{math.NaN(), math.NaN(), 1, 2},
{math.NaN(), 3, 4, 5},
{6, 7, 8, math.NaN()},
{9, 10, math.NaN(), math.NaN()},
{11, math.NaN(), math.NaN(), math.NaN()},
},
},
{
dense: [][]float64{
{1, 0, 0},
{3, 4, 0},
{6, 7, 8},
{0, 9, 10},
{0, 0, 11},
},
ku: 0,
kl: 2,
condensed: [][]float64{
{math.NaN(), math.NaN(), 1},
{math.NaN(), 3, 4},
{6, 7, 8},
{9, 10, math.NaN()},
{11, math.NaN(), math.NaN()},
},
},
{
dense: [][]float64{
{1, 0, 0, 0, 0},
{3, 4, 0, 0, 0},
{1, 3, 5, 0, 0},
},
ku: 0,
kl: 2,
condensed: [][]float64{
{math.NaN(), math.NaN(), 1},
{math.NaN(), 3, 4},
{1, 3, 5},
},
},
} {
condensed := flattenBanded(test.dense, test.ku, test.kl)
correct := flatten(test.condensed)
// NOTE(review): floats.Same is presumably used because the expected data
// contains NaN, which never compares equal with == — confirm against the
// gonum/floats documentation.
if !floats.Same(condensed, correct) {
t.Errorf("Case %v mismatch. Want %v, got %v.", i, correct, condensed)
}
}
}
// TestFlattenTriangular checks that flattenTriangular packs the selected
// triangle of a square matrix row by row into a single slice.
func TestFlattenTriangular(t *testing.T) {
	type packCase struct {
		ul  blas.Uplo
		a   [][]float64
		ans []float64
	}
	cases := []packCase{
		{
			ul: blas.Upper,
			a: [][]float64{
				{1, 2, 3},
				{0, 4, 5},
				{0, 0, 6},
			},
			ans: []float64{1, 2, 3, 4, 5, 6},
		},
		{
			ul: blas.Lower,
			a: [][]float64{
				{1, 0, 0},
				{2, 3, 0},
				{4, 5, 6},
			},
			ans: []float64{1, 2, 3, 4, 5, 6},
		},
	}
	for i, tc := range cases {
		got := flattenTriangular(tc.a, tc.ul)
		if floats.Equal(got, tc.ans) {
			continue
		}
		t.Errorf("Case %v. Want %v, got %v.", i, tc.ans, got)
	}
}

94
blas/testblas/dgbmv.go Normal file
View File

@@ -0,0 +1,94 @@
package testblas
import (
"testing"
"github.com/gonum/blas"
)
// Dgbmver is the interface an implementation must satisfy to be exercised by
// DgbmvTest. The fixtures there expect Dgbmv to overwrite y with
// alpha*op(A)*x + beta*y, where A is supplied in band storage.
type Dgbmver interface {
Dgbmv(tA blas.Transpose, m, n, kL, kU int, alpha float64, a []float64, lda int, x []float64, incX int, beta float64, y []float64, incY int)
}
// DgbmvTest runs blasser.Dgbmv on a fixed set of banded matrices, once without
// and once with transposition, and compares the resulting y with the expected
// answer. Each case is repeated for several combinations of positive and
// negative incX and incY.
func DgbmvTest(t *testing.T, blasser Dgbmver) {
for i, test := range []struct {
tA blas.Transpose
m, n int
kL, kU int
alpha float64
a [][]float64 // dense form; converted to band storage before the call
lda int
x []float64
beta float64
y []float64
ans []float64 // expected contents of y after the call
}{
{
tA: blas.NoTrans,
m: 9,
n: 6,
lda: 4,
kL: 2,
kU: 1,
alpha: 3.0,
beta: 2.0,
a: [][]float64{
{5, 3, 0, 0, 0, 0},
{-1, 2, 9, 0, 0, 0},
{4, 8, 3, 6, 0, 0},
{0, -1, 8, 2, 1, 0},
{0, 0, 9, 9, 9, 5},
{0, 0, 0, 2, -3, 2},
{0, 0, 0, 0, 1, 5},
{0, 0, 0, 0, 0, 6},
{0, 0, 0, 0, 0, 0},
},
x: []float64{1, 2, 3, 4, 5, 6},
y: []float64{-1, -2, -3, -4, -5, -6, -7, -8, -9},
ans: []float64{31, 86, 153, 97, 404, 3, 91, 92, -18},
},
{
tA: blas.Trans,
m: 9,
n: 6,
lda: 4,
kL: 2,
kU: 1,
alpha: 3.0,
beta: 2.0,
a: [][]float64{
{5, 3, 0, 0, 0, 0},
{-1, 2, 9, 0, 0, 0},
{4, 8, 3, 6, 0, 0},
{0, -1, 8, 2, 1, 0},
{0, 0, 9, 9, 9, 5},
{0, 0, 0, 2, -3, 2},
{0, 0, 0, 0, 1, 5},
{0, 0, 0, 0, 0, 6},
{0, 0, 0, 0, 0, 0},
},
x: []float64{1, 2, 3, 4, 5, 6, 7, 8, 9},
y: []float64{-1, -2, -3, -4, -5, -6},
ans: []float64{43, 77, 306, 241, 104, 348},
},
} {
extra := 3
aFlat := flattenBanded(test.a, test.kU, test.kL)
// incTest expands x, y and the expected answer to the requested strides
// (with extra trailing filler values), calls Dgbmv and compares y with
// the expected answer within tolerance.
incTest := func(incX, incY, extra int) {
xnew := makeIncremented(test.x, incX, extra)
ynew := makeIncremented(test.y, incY, extra)
ans := makeIncremented(test.ans, incY, extra)
blasser.Dgbmv(test.tA, test.m, test.n, test.kL, test.kU, test.alpha, aFlat, test.lda, xnew, incX, test.beta, ynew, incY)
if !dSliceTolEqual(ans, ynew) {
t.Errorf("Case %v: Want %v, got %v", i, ans, ynew)
}
}
incTest(1, 1, extra)
incTest(1, 3, extra)
incTest(1, -3, extra)
incTest(2, 3, extra)
incTest(2, -3, extra)
incTest(3, 2, extra)
incTest(-3, 2, extra)
}
}

252
blas/testblas/dgemm.go Normal file
View File

@@ -0,0 +1,252 @@
package testblas
import (
"testing"
"github.com/gonum/blas"
)
// Dgemmer is the interface an implementation must satisfy to be exercised by
// TestDgemm and DgemmBenchmark. The fixtures in DgemmCases expect Dgemm to
// overwrite c with alpha*op(a)*op(b) + beta*c.
type Dgemmer interface {
Dgemm(tA, tB blas.Transpose, m, n, k int, alpha float64, a []float64, lda int, b []float64, ldb int, beta float64, c []float64, ldc int)
}
// DgemmCase is one matrix-multiplication fixture: the operands a (m×k) and
// b (k×n) in untransposed form, the initial c (m×n), the scalars alpha and
// beta, and the expected result ans.
type DgemmCase struct {
isATrans bool
m, n, k int
alpha, beta float64
a [][]float64
aTrans [][]float64 // transpose of a
b [][]float64
c [][]float64
ans [][]float64
}
// DgemmCases is the fixture table used by TestDgemm. It covers each ordering
// of the sizes m, n and k drawn from {2, 3, 4}; every case uses alpha = 2 and
// beta = 0.5, and ans holds alpha*a*b + beta*c.
var DgemmCases = []DgemmCase{
{
m: 4,
n: 3,
k: 2,
isATrans: false,
alpha: 2,
beta: 0.5,
a: [][]float64{
{1, 2},
{4, 5},
{7, 8},
{10, 11},
},
b: [][]float64{
{1, 5, 6},
{5, -8, 8},
},
c: [][]float64{
{4, 8, -9},
{12, 16, -8},
{1, 5, 15},
{-3, -4, 7},
},
ans: [][]float64{
{24, -18, 39.5},
{64, -32, 124},
{94.5, -55.5, 219.5},
{128.5, -78, 299.5},
},
},
{
m: 4,
n: 2,
k: 3,
isATrans: false,
alpha: 2,
beta: 0.5,
a: [][]float64{
{1, 2, 3},
{4, 5, 6},
{7, 8, 9},
{10, 11, 12},
},
b: [][]float64{
{1, 5},
{5, -8},
{6, 2},
},
c: [][]float64{
{4, 8},
{12, 16},
{1, 5},
{-3, -4},
},
ans: [][]float64{
{60, -6},
{136, -8},
{202.5, -19.5},
{272.5, -30},
},
},
{
m: 3,
n: 2,
k: 4,
isATrans: false,
alpha: 2,
beta: 0.5,
a: [][]float64{
{1, 2, 3, 4},
{4, 5, 6, 7},
{8, 9, 10, 11},
},
b: [][]float64{
{1, 5},
{5, -8},
{6, 2},
{8, 10},
},
c: [][]float64{
{4, 8},
{12, 16},
{9, -10},
},
ans: [][]float64{
{124, 74},
{248, 132},
{406.5, 191},
},
},
{
m: 3,
n: 4,
k: 2,
isATrans: false,
alpha: 2,
beta: 0.5,
a: [][]float64{
{1, 2},
{4, 5},
{8, 9},
},
b: [][]float64{
{1, 5, 2, 1},
{5, -8, 2, 1},
},
c: [][]float64{
{4, 8, 2, 2},
{12, 16, 8, 9},
{9, -10, 10, 10},
},
ans: [][]float64{
{24, -18, 13, 7},
{64, -32, 40, 22.5},
{110.5, -69, 73, 39},
},
},
{
m: 2,
n: 4,
k: 3,
isATrans: false,
alpha: 2,
beta: 0.5,
a: [][]float64{
{1, 2, 3},
{4, 5, 6},
},
b: [][]float64{
{1, 5, 8, 8},
{5, -8, 9, 10},
{6, 2, -3, 2},
},
c: [][]float64{
{4, 8, 7, 8},
{12, 16, -2, 6},
},
ans: [][]float64{
{60, -6, 37.5, 72},
{136, -8, 117, 191},
},
},
{
m: 2,
n: 3,
k: 4,
isATrans: false,
alpha: 2,
beta: 0.5,
a: [][]float64{
{1, 2, 3, 4},
{4, 5, 6, 7},
},
b: [][]float64{
{1, 5, 8},
{5, -8, 9},
{6, 2, -3},
{8, 10, 2},
},
c: [][]float64{
{4, 8, 1},
{12, 16, 6},
},
ans: [][]float64{
{124, 74, 50.5},
{248, 132, 149},
},
},
}
// transpose returns a newly allocated transpose of a. It assumes a is a proper
// (non-empty, rectangular) matrix: the column count is read from a[0].
func transpose(a [][]float64) [][]float64 {
	rows, cols := len(a), len(a[0])
	out := make([][]float64, cols)
	for c := range out {
		out[c] = make([]float64, rows)
	}
	for r := 0; r < rows; r++ {
		for c := 0; c < cols; c++ {
			out[c][r] = a[r][c]
		}
	}
	return out
}
// TestDgemm runs every fixture in DgemmCases through blasser.Dgemm four times:
// with neither operand transposed, with only A transposed, with only B
// transposed, and with both transposed. When an operand is flagged as
// transposed, the transposed data is passed so the expected answer is the same
// in all four runs.
func TestDgemm(t *testing.T, blasser Dgemmer) {
	variants := []struct {
		name   string
		tA, tB blas.Transpose
	}{
		{"RowMajorNoTrans", blas.NoTrans, blas.NoTrans},
		{"RowMajorTransA", blas.Trans, blas.NoTrans},
		{"RowMajorTransB", blas.NoTrans, blas.Trans},
		{"RowMajorTransBoth", blas.Trans, blas.Trans},
	}
	for i, test := range DgemmCases {
		for _, v := range variants {
			a, b := test.a, test.b
			if v.tA == blas.Trans {
				a = transpose(a)
			}
			if v.tB == blas.Trans {
				b = transpose(b)
			}
			dgemmcomp(i, v.name, t, blasser, v.tA, v.tB,
				test.m, test.n, test.k, test.alpha, test.beta, a, b, test.c, test.ans)
		}
	}
}
// dgemmcomp flattens a, b and c, calls blasser.Dgemm with the given transpose
// flags and sizes, and reports an error on t if a or b was modified by the
// call or if the resulting c differs from ans beyond tolerance. The leading
// dimensions are taken from the row length of each operand as passed in. i and
// name identify the fixture and variant in error messages.
func dgemmcomp(i int, name string, t *testing.T, blasser Dgemmer, tA, tB blas.Transpose, m, n, k int,
	alpha, beta float64, a [][]float64, b [][]float64, c [][]float64, ans [][]float64) {
	// Keep pristine copies of the inputs to detect unwanted writes.
	aFlat, aOrig := flatten(a), flatten(a)
	bFlat, bOrig := flatten(b), flatten(b)
	cFlat := flatten(c)
	want := flatten(ans)

	// Compute the matrix multiplication
	blasser.Dgemm(tA, tB, m, n, k, alpha, aFlat, len(a[0]), bFlat, len(b[0]), beta, cFlat, len(c[0]))

	if !dSliceEqual(aFlat, aOrig) {
		t.Errorf("Test %v case %v: a changed during call to Dgemm", i, name)
	}
	if !dSliceEqual(bFlat, bOrig) {
		t.Errorf("Test %v case %v: b changed during call to Dgemm", i, name)
	}
	if !dSliceTolEqual(want, cFlat) {
		t.Errorf("Test %v case %v: answer mismatch. Expected %v, Found %v", i, name, want, cFlat)
	}
	// TODO: Need to add a sub-slice test where don't use up full matrix
}

View File

@@ -0,0 +1,39 @@
package testblas
import (
"math/rand"
"testing"
"github.com/gonum/blas"
)
// DgemmBenchmark times dgemm.Dgemm on randomly filled operands whose sizes are
// given by m, n and k: a has m*k elements, b has k*n and c has m*n. The
// leading dimensions of a and b are chosen according to tA and tB. Operand
// setup is excluded from the timing.
func DgemmBenchmark(b *testing.B, dgemm Dgemmer, m, n, k int, tA, tB blas.Transpose) {
	randomSlice := func(size int) []float64 {
		s := make([]float64, size)
		for i := range s {
			s[i] = rand.Float64()
		}
		return s
	}
	a := randomSlice(m * k)
	bv := randomSlice(k * n)
	c := randomSlice(m * n)

	// Untransposed operands are stored m×k and k×n; transposed ones are
	// stored k×m and n×k.
	lda, ldb, ldc := k, n, n
	if tA == blas.Trans {
		lda = m
	}
	if tB == blas.Trans {
		ldb = k
	}

	b.ResetTimer()
	for iter := 0; iter < b.N; iter++ {
		dgemm.Dgemm(tA, tB, m, n, k, 3.0, a, lda, bv, ldb, 1.0, c, ldc)
	}
}

680
blas/testblas/dgemv.go Normal file
View File

@@ -0,0 +1,680 @@
package testblas
import (
"testing"
"github.com/gonum/blas"
)
// DgemvCase is one matrix-vector fixture for DgemvTest: an m×n matrix A, a
// transpose flag, the vectors x and y together with their (positive)
// increments, and a list of Subcases that vary alpha, beta and the sign of the
// increments.
type DgemvCase struct {
Name string
m int
n int
A [][]float64
tA blas.Transpose
x []float64
incX int
y []float64
incY int
// xCopy and yCopy are not populated by any entry of DgemvCases.
xCopy []float64
yCopy []float64
Subcases []DgemvSubcase
}
// DgemvSubcase is one set of scalars and increment signs to apply to a
// DgemvCase, together with the expected contents of y after the call.
type DgemvSubcase struct {
mulXNeg1 bool // when set, dgemvcomp negates the case's incX
mulYNeg1 bool // when set, dgemvcomp negates the case's incY
alpha float64
beta float64
ans []float64
}
// DgemvCases is the fixture table used by DgemvTest. The cases cover
// m > n, m == n and m < n, with and without transposition, first with unit
// increments, then with a non-unit incY only ("Part1"), then with both
// increments non-unit ("IncNot1"). In the strided cases the entries of y that
// lie between the strided elements are expected to be left untouched.
var DgemvCases = []DgemvCase{
{
Name: "M_gt_N_Inc1_NoTrans",
tA: blas.NoTrans,
m: 5,
n: 3,
A: [][]float64{
{4.1, 6.2, 8.1},
{9.6, 3.5, 9.1},
{10, 7, 3},
{1, 1, 2},
{9, 2, 5},
},
incX: 1,
incY: 1,
x: []float64{1, 2, 3},
y: []float64{7, 8, 9, 10, 11},
Subcases: []DgemvSubcase{
{
alpha: 0,
beta: 0,
ans: []float64{0, 0, 0, 0, 0},
},
{
alpha: 0,
beta: 1,
ans: []float64{7, 8, 9, 10, 11},
},
{
alpha: 1,
beta: 0,
ans: []float64{40.8, 43.9, 33, 9, 28},
},
{
alpha: 8,
beta: -6,
ans: []float64{284.4, 303.2, 210, 12, 158},
},
},
},
{
Name: "M_gt_N_Inc1_Trans",
tA: blas.Trans,
m: 5,
n: 3,
A: [][]float64{
{4.1, 6.2, 8.1},
{9.6, 3.5, 9.1},
{10, 7, 3},
{1, 1, 2},
{9, 2, 5},
},
incX: 1,
incY: 1,
x: []float64{1, 2, 3, -4, 5},
y: []float64{7, 8, 9},
Subcases: []DgemvSubcase{
{
alpha: 0,
beta: 0,
ans: []float64{0, 0, 0},
},
{
alpha: 0,
beta: 1,
ans: []float64{7, 8, 9},
},
{
alpha: 1,
beta: 0,
ans: []float64{94.3, 40.2, 52.3},
},
{
alpha: 8,
beta: -6,
ans: []float64{712.4, 273.6, 364.4},
},
},
},
{
Name: "M_eq_N_Inc1_NoTrans",
tA: blas.NoTrans,
m: 3,
n: 3,
A: [][]float64{
{4.1, 6.2, 8.1},
{9.6, 3.5, 9.1},
{10, 7, 3},
},
incX: 1,
incY: 1,
x: []float64{1, 2, 3},
y: []float64{7, 2, 2},
Subcases: []DgemvSubcase{
{
alpha: 0,
beta: 0,
ans: []float64{0, 0, 0},
},
{
alpha: 0,
beta: 1,
ans: []float64{7, 2, 2},
},
{
alpha: 1,
beta: 0,
ans: []float64{40.8, 43.9, 33},
},
{
alpha: 8,
beta: -6,
ans: []float64{40.8*8 - 6*7, 43.9*8 - 6*2, 33*8 - 6*2},
},
},
},
{
Name: "M_eq_N_Inc1_Trans",
tA: blas.Trans,
m: 3,
n: 3,
A: [][]float64{
{4.1, 6.2, 8.1},
{9.6, 3.5, 9.1},
{10, 7, 3},
},
incX: 1,
incY: 1,
x: []float64{1, 2, 3},
y: []float64{7, 2, 2},
Subcases: []DgemvSubcase{
{
alpha: 8,
beta: -6,
ans: []float64{384.4, 261.6, 270.4},
},
},
},
{
Name: "M_lt_N_Inc1_NoTrans",
tA: blas.NoTrans,
m: 3,
n: 5,
A: [][]float64{
{4.1, 6.2, 8.1, 10, 7},
{9.6, 3.5, 9.1, -2, 9},
{10, 7, 3, 1, -5},
},
incX: 1,
incY: 1,
x: []float64{1, 2, 3, -7.6, 8.1},
y: []float64{7, 2, 2},
Subcases: []DgemvSubcase{
{
alpha: 0,
beta: 0,
ans: []float64{0, 0, 0},
},
{
alpha: 0,
beta: 1,
ans: []float64{7, 2, 2},
},
{
alpha: 1,
beta: 0,
ans: []float64{21.5, 132, -15.1},
},
{
alpha: 8,
beta: -6,
ans: []float64{21.5*8 - 6*7, 132*8 - 6*2, -15.1*8 - 6*2},
},
},
},
{
Name: "M_lt_N_Inc1_Trans",
tA: blas.Trans,
m: 3,
n: 5,
A: [][]float64{
{4.1, 6.2, 8.1, 10, 7},
{9.6, 3.5, 9.1, -2, 9},
{10, 7, 3, 1, -5},
},
incX: 1,
incY: 1,
x: []float64{1, 2, 3},
y: []float64{7, 2, 2, -3, 5},
Subcases: []DgemvSubcase{
{
alpha: 8,
beta: -6,
ans: []float64{384.4, 261.6, 270.4, 90, 50},
},
},
},
{
Name: "M_gt_N_Part1_NoTrans",
tA: blas.NoTrans,
m: 5,
n: 3,
A: [][]float64{
{4.1, 6.2, 8.1},
{9.6, 3.5, 9.1},
{10, 7, 3},
{1, 1, 2},
{9, 2, 5},
},
incX: 1,
incY: 2,
x: []float64{1, 2, 3},
y: []float64{7, 100, 8, 101, 9, 102, 10, 103, 11},
Subcases: []DgemvSubcase{
{
alpha: 0,
beta: 0,
ans: []float64{0, 100, 0, 101, 0, 102, 0, 103, 0},
},
{
alpha: 0,
beta: 1,
ans: []float64{7, 100, 8, 101, 9, 102, 10, 103, 11},
},
{
alpha: 1,
beta: 0,
ans: []float64{40.8, 100, 43.9, 101, 33, 102, 9, 103, 28},
},
{
alpha: 8,
beta: -6,
ans: []float64{284.4, 100, 303.2, 101, 210, 102, 12, 103, 158},
},
},
},
{
Name: "M_gt_N_Part1_Trans",
tA: blas.Trans,
m: 5,
n: 3,
A: [][]float64{
{4.1, 6.2, 8.1},
{9.6, 3.5, 9.1},
{10, 7, 3},
{1, 1, 2},
{9, 2, 5},
},
incX: 1,
incY: 2,
x: []float64{1, 2, 3, -4, 5},
y: []float64{7, 100, 8, 101, 9},
Subcases: []DgemvSubcase{
{
alpha: 0,
beta: 0,
ans: []float64{0, 100, 0, 101, 0},
},
{
alpha: 0,
beta: 1,
ans: []float64{7, 100, 8, 101, 9},
},
{
alpha: 1,
beta: 0,
ans: []float64{94.3, 100, 40.2, 101, 52.3},
},
{
alpha: 8,
beta: -6,
ans: []float64{712.4, 100, 273.6, 101, 364.4},
},
},
},
{
Name: "M_gt_N_IncNot1_NoTrans",
tA: blas.NoTrans,
m: 5,
n: 3,
A: [][]float64{
{4.1, 6.2, 8.1},
{9.6, 3.5, 9.1},
{10, 7, 3},
{1, 1, 2},
{9, 2, 5},
},
incX: 2,
incY: 3,
x: []float64{1, 15, 2, 150, 3},
y: []float64{7, 2, 6, 8, -4, -5, 9, 1, 1, 10, 19, 22, 11},
Subcases: []DgemvSubcase{
{
alpha: 8,
beta: -6,
ans: []float64{284.4, 2, 6, 303.2, -4, -5, 210, 1, 1, 12, 19, 22, 158},
},
{
mulXNeg1: true,
alpha: 8,
beta: -6,
ans: []float64{220.4, 2, 6, 311.2, -4, -5, 322, 1, 1, -4, 19, 22, 222},
},
{
mulYNeg1: true,
alpha: 8,
beta: -6,
ans: []float64{182, 2, 6, 24, -4, -5, 210, 1, 1, 291.2, 19, 22, 260.4},
},
{
mulXNeg1: true,
mulYNeg1: true,
alpha: 8,
beta: -6,
ans: []float64{246, 2, 6, 8, -4, -5, 322, 1, 1, 299.2, 19, 22, 196.4},
},
},
},
{
Name: "M_gt_N_IncNot1_Trans",
tA: blas.Trans,
m: 5,
n: 3,
A: [][]float64{
{4.1, 6.2, 8.1},
{9.6, 3.5, 9.1},
{10, 7, 3},
{1, 1, 2},
{9, 2, 5},
},
incX: 2,
incY: 3,
x: []float64{1, 15, 2, 150, 3, 8, -3, 6, 5},
y: []float64{7, 2, 6, 8, -4, -5, 9},
Subcases: []DgemvSubcase{
{
alpha: 8,
beta: -6,
ans: []float64{720.4, 2, 6, 281.6, -4, -5, 380.4},
},
{
mulXNeg1: true,
alpha: 8,
beta: -6,
ans: []float64{219.6, 2, 6, 316, -4, -5, 195.6},
},
{
mulYNeg1: true,
alpha: 8,
beta: -6,
ans: []float64{392.4, 2, 6, 281.6, -4, -5, 708.4},
},
{
mulXNeg1: true,
mulYNeg1: true,
alpha: 8,
beta: -6,
ans: []float64{207.6, 2, 6, 316, -4, -5, 207.6},
},
},
},
{
Name: "M_eq_N_IncNot1_NoTrans",
tA: blas.NoTrans,
m: 3,
n: 3,
A: [][]float64{
{4.1, 6.2, 8.1},
{9.6, 3.5, 9.1},
{10, 7, 3},
},
incX: 2,
incY: 3,
x: []float64{1, 15, 2, 150, 3},
y: []float64{7, 2, 6, 8, -4, -5, 9},
Subcases: []DgemvSubcase{
{
alpha: 8,
beta: -6,
ans: []float64{284.4, 2, 6, 303.2, -4, -5, 210},
},
{
mulXNeg1: true,
alpha: 8,
beta: -6,
ans: []float64{220.4, 2, 6, 311.2, -4, -5, 322},
},
{
mulYNeg1: true,
alpha: 8,
beta: -6,
ans: []float64{222, 2, 6, 303.2, -4, -5, 272.4},
},
{
mulXNeg1: true,
mulYNeg1: true,
alpha: 8,
beta: -6,
ans: []float64{334, 2, 6, 311.2, -4, -5, 208.4},
},
},
},
{
Name: "M_eq_N_IncNot1_Trans",
tA: blas.Trans,
m: 3,
n: 3,
A: [][]float64{
{4.1, 6.2, 8.1},
{9.6, 3.5, 9.1},
{10, 7, 3},
},
incX: 2,
incY: 3,
x: []float64{1, 15, 2, 150, 3},
y: []float64{7, 2, 6, 8, -4, -5, 9},
Subcases: []DgemvSubcase{
{
alpha: 8,
beta: -6,
ans: []float64{384.4, 2, 6, 225.6, -4, -5, 228.4},
},
{
mulXNeg1: true,
alpha: 8,
beta: -6,
ans: []float64{290, 2, 6, 212.8, -4, -5, 310},
},
{
mulYNeg1: true,
alpha: 8,
beta: -6,
ans: []float64{240.4, 2, 6, 225.6, -4, -5, 372.4},
},
{
mulXNeg1: true,
mulYNeg1: true,
alpha: 8,
beta: -6,
ans: []float64{322, 2, 6, 212.8, -4, -5, 278},
},
},
},
{
Name: "M_lt_N_IncNot1_NoTrans",
tA: blas.NoTrans,
m: 3,
n: 5,
A: [][]float64{
{4.1, 6.2, 8.1, 10, 11},
{9.6, 3.5, 9.1, -3, -2},
{10, 7, 3, -7, -4},
},
incX: 2,
incY: 3,
x: []float64{1, 15, 2, 150, 3, -2, -4, 8, -9},
y: []float64{7, 2, 6, 8, -4, -5, 9},
Subcases: []DgemvSubcase{
{
alpha: 8,
beta: -6,
ans: []float64{-827.6, 2, 6, 543.2, -4, -5, 722},
},
{
mulXNeg1: true,
alpha: 8,
beta: -6,
ans: []float64{-93.2, 2, 6, -696.8, -4, -5, -1070},
},
{
mulYNeg1: true,
alpha: 8,
beta: -6,
ans: []float64{734, 2, 6, 543.2, -4, -5, -839.6},
},
{
mulXNeg1: true,
mulYNeg1: true,
alpha: 8,
beta: -6,
ans: []float64{-1058, 2, 6, -696.8, -4, -5, -105.2},
},
},
},
{
Name: "M_lt_N_IncNot1_Trans",
tA: blas.Trans,
m: 3,
n: 5,
A: [][]float64{
{4.1, 6.2, 8.1, 10, 11},
{9.6, 3.5, 9.1, -3, -2},
{10, 7, 3, -7, -4},
},
incX: 2,
incY: 3,
x: []float64{1, 15, 2, 150, 3},
y: []float64{7, 2, 6, 8, -4, -5, 9, -4, -1, -9, 1, 1, 2},
Subcases: []DgemvSubcase{
{
alpha: 8,
beta: -6,
ans: []float64{384.4, 2, 6, 225.6, -4, -5, 228.4, -4, -1, -82, 1, 1, -52},
},
{
mulXNeg1: true,
alpha: 8,
beta: -6,
ans: []float64{290, 2, 6, 212.8, -4, -5, 310, -4, -1, 190, 1, 1, 188},
},
{
mulYNeg1: true,
alpha: 8,
beta: -6,
ans: []float64{-82, 2, 6, -184, -4, -5, 228.4, -4, -1, 327.6, 1, 1, 414.4},
},
{
mulXNeg1: true,
mulYNeg1: true,
alpha: 8,
beta: -6,
ans: []float64{158, 2, 6, 88, -4, -5, 310, -4, -1, 314.8, 1, 1, 320},
},
},
},
// TODO: A can be longer than mxn. Add cases where it is longer
// TODO: x and y can also be longer. Add tests for these
// TODO: Add tests for dimension mismatch
// TODO: Add places with a "submatrix view", where lda != m
}
// Dgemver is the interface an implementation must satisfy to be exercised by
// DgemvTest. The fixtures in DgemvCases expect Dgemv to overwrite y with
// alpha*op(A)*x + beta*y.
type Dgemver interface {
Dgemv(tA blas.Transpose, m, n int, alpha float64, a []float64, lda int, x []float64, incX int, beta float64, y []float64, incY int)
}
// DgemvTest exercises blasser.Dgemv with every subcase of every entry in
// DgemvCases, checking first the computed result and then that invalid
// arguments cause a panic.
func DgemvTest(t *testing.T, blasser Dgemver) {
	for c := range DgemvCases {
		test := DgemvCases[c]
		for i := range test.Subcases {
			sub := test.Subcases[i]
			// Test that it passes with row-major
			dgemvcomp(t, test, sub, i, blasser)
			// Test the bad inputs
			dgemvbad(t, test, sub, i, blasser)
		}
	}
}
// dgemvcomp runs one subcase of a Dgemv fixture on copies of the fixture data
// and reports an error on t if the call panics, if x or A is modified, or if y
// differs from the expected answer beyond tolerance. The subcase may flip the
// sign of either increment. i identifies the subcase in error messages.
func dgemvcomp(t *testing.T, test DgemvCase, cas DgemvSubcase, i int, blasser Dgemver) {
	x := sliceCopy(test.x)
	y := sliceCopy(test.y)
	aFlat := flatten(test.A) // flatten allocates, so the fixture is not shared
	lda := test.n

	incX, incY := test.incX, test.incY
	if cas.mulXNeg1 {
		incX = -incX
	}
	if cas.mulYNeg1 {
		incY = -incY
	}

	call := func() {
		blasser.Dgemv(test.tA, test.m, test.n, cas.alpha, aFlat, lda, x, incX, cas.beta, y, incY)
	}
	if panics(call) {
		t.Errorf("Test %v case %v: unexpected panic", test.Name, i)
		if throwPanic {
			// Repeat the call outside the recover wrapper so the panic
			// surfaces.
			call()
		}
		return
	}

	// Check that x and a are unchanged
	if !dSliceEqual(x, test.x) {
		t.Errorf("Test %v, case %v: x modified during call", test.Name, i)
	}
	if pristine := flatten(test.A); !dSliceEqual(pristine, aFlat) {
		t.Errorf("Test %v, case %v: a modified during call", test.Name, i)
	}

	// Check that the answer matches
	if !dSliceTolEqual(cas.ans, y) {
		t.Errorf("Test %v, case %v: answer mismatch: Expected %v, Found %v", test.Name, i, cas.ans, y)
	}
}
// dgemvbad checks that blasser.Dgemv panics when given, in turn, an invalid
// transpose value, a negative m, a negative n, a zero incX, a zero incY, and
// an lda one less than n. All other arguments come from the fixture, using its
// unmodified (positive) increments.
func dgemvbad(t *testing.T, test DgemvCase, cas DgemvSubcase, i int, blasser Dgemver) {
	x := sliceCopy(test.x)
	y := sliceCopy(test.y)
	aFlatRow := flatten(sliceOfSliceCopy(test.A))
	ldaRow := test.n

	badCalls := []struct {
		format string // error reported when the call does not panic
		call   func()
	}{
		{
			"Test %v case %v: no panic for bad transpose",
			func() {
				blasser.Dgemv(312, test.m, test.n, cas.alpha, aFlatRow, ldaRow, x, test.incX, cas.beta, y, test.incY)
			},
		},
		{
			"Test %v case %v: no panic for m negative",
			func() {
				blasser.Dgemv(test.tA, -2, test.n, cas.alpha, aFlatRow, ldaRow, x, test.incX, cas.beta, y, test.incY)
			},
		},
		{
			"Test %v case %v: no panic for n negative",
			func() {
				blasser.Dgemv(test.tA, test.m, -4, cas.alpha, aFlatRow, ldaRow, x, test.incX, cas.beta, y, test.incY)
			},
		},
		{
			"Test %v case %v: no panic for incX zero",
			func() {
				blasser.Dgemv(test.tA, test.m, test.n, cas.alpha, aFlatRow, ldaRow, x, 0, cas.beta, y, test.incY)
			},
		},
		{
			"Test %v case %v: no panic for incY zero",
			func() {
				blasser.Dgemv(test.tA, test.m, test.n, cas.alpha, aFlatRow, ldaRow, x, test.incX, cas.beta, y, 0)
			},
		},
		{
			"Test %v case %v: no panic for lda too small row major",
			func() {
				blasser.Dgemv(test.tA, test.m, test.n, cas.alpha, aFlatRow, ldaRow-1, x, test.incX, cas.beta, y, test.incY)
			},
		},
	}
	for _, bad := range badCalls {
		if !panics(bad.call) {
			t.Errorf(bad.format, test.Name, i)
		}
	}
}

164
blas/testblas/dger.go Normal file
View File

@@ -0,0 +1,164 @@
package testblas
import "testing"
// Dgerer is the interface an implementation must satisfy to be exercised by
// DgerTest. The fixtures there expect Dger to add alpha*x*yᵀ to a in place.
type Dgerer interface {
Dger(m, n int, alpha float64, x []float64, incX int, y []float64, incY int, a []float64, lda int)
}
// DgerTest runs blasser.Dger on a fixed set of rank-one update problems with
// m > n, m == n and m < n, using unit and non-unit positive increments. Each
// case is run with alpha = 1 against trueAns, and again with alpha = 4 against
// an answer derived from trueAns.
func DgerTest(t *testing.T, blasser Dgerer) {
for _, test := range []struct {
name string
a [][]float64
m int
n int
x []float64
y []float64
incX int
incY int
ansAlphaEq1 []float64 // not set or read by anything in this function
trueAns [][]float64 // expected a after the call with alpha = 1
}{
{
name: "M gt N inc 1",
m: 5,
n: 3,
a: [][]float64{
{1.3, 2.4, 3.5},
{2.6, 2.8, 3.3},
{-1.3, -4.3, -9.7},
{8, 9, -10},
{-12, -14, -6},
},
x: []float64{-2, -3, 0, 1, 2},
y: []float64{-1.1, 5, 0},
incX: 1,
incY: 1,
trueAns: [][]float64{{3.5, -7.6, 3.5}, {5.9, -12.2, 3.3}, {-1.3, -4.3, -9.7}, {6.9, 14, -10}, {-14.2, -4, -6}},
},
{
name: "M eq N inc 1",
m: 3,
n: 3,
a: [][]float64{
{1.3, 2.4, 3.5},
{2.6, 2.8, 3.3},
{-1.3, -4.3, -9.7},
},
x: []float64{-2, -3, 0},
y: []float64{-1.1, 5, 0},
incX: 1,
incY: 1,
trueAns: [][]float64{{3.5, -7.6, 3.5}, {5.9, -12.2, 3.3}, {-1.3, -4.3, -9.7}},
},
{
name: "M lt N inc 1",
m: 3,
n: 6,
a: [][]float64{
{1.3, 2.4, 3.5, 4.8, 1.11, -9},
{2.6, 2.8, 3.3, -3.4, 6.2, -8.7},
{-1.3, -4.3, -9.7, -3.1, 8.9, 8.9},
},
x: []float64{-2, -3, 0},
y: []float64{-1.1, 5, 0, 9, 19, 22},
incX: 1,
incY: 1,
trueAns: [][]float64{{3.5, -7.6, 3.5, -13.2, -36.89, -53}, {5.9, -12.2, 3.3, -30.4, -50.8, -74.7}, {-1.3, -4.3, -9.7, -3.1, 8.9, 8.9}},
},
{
name: "M gt N inc not 1",
m: 5,
n: 3,
a: [][]float64{
{1.3, 2.4, 3.5},
{2.6, 2.8, 3.3},
{-1.3, -4.3, -9.7},
{8, 9, -10},
{-12, -14, -6},
},
x: []float64{-2, -3, 0, 1, 2, 6, 0, 9, 7},
y: []float64{-1.1, 5, 0, 8, 7, -5, 7},
incX: 2,
incY: 3,
trueAns: [][]float64{{3.5, -13.6, -10.5}, {2.6, 2.8, 3.3}, {-3.5, 11.7, 4.3}, {8, 9, -10}, {-19.700000000000003, 42, 43}},
},
{
name: "M eq N inc not 1",
m: 3,
n: 3,
a: [][]float64{
{1.3, 2.4, 3.5},
{2.6, 2.8, 3.3},
{-1.3, -4.3, -9.7},
},
x: []float64{-2, -3, 0, 8, 7, -9, 7, -6, 12, 6, 6, 6, -11},
y: []float64{-1.1, 5, 0, 0, 9, 8, 6},
incX: 4,
incY: 3,
trueAns: [][]float64{{3.5, 2.4, -8.5}, {-5.1, 2.8, 45.3}, {-14.5, -4.3, 62.3}},
},
{
name: "M lt N inc not 1",
m: 3,
n: 6,
a: [][]float64{
{1.3, 2.4, 3.5, 4.8, 1.11, -9},
{2.6, 2.8, 3.3, -3.4, 6.2, -8.7},
{-1.3, -4.3, -9.7, -3.1, 8.9, 8.9},
},
x: []float64{-2, -3, 0, 0, 8, 0, 9, -3},
y: []float64{-1.1, 5, 0, 9, 19, 22, 11, -8.11, -9.22, 9.87, 7},
incX: 3,
incY: 2,
trueAns: [][]float64{{3.5, 2.4, -34.5, -17.2, 19.55, -23}, {2.6, 2.8, 3.3, -3.4, 6.2, -8.7}, {-11.2, -4.3, 161.3, 95.9, -74.08, 71.9}},
},
} {
// TODO: Add tests where a is longer
// TODO: Add panic tests
// TODO: Add negative increment tests
x := sliceCopy(test.x)
y := sliceCopy(test.y)
a := sliceOfSliceCopy(test.a)
// Test with row major
alpha := 1.0
aFlat := flatten(a)
blasser.Dger(test.m, test.n, alpha, x, test.incX, y, test.incY, aFlat, test.n)
ans := unflatten(aFlat, test.m, test.n)
dgercomp(t, x, test.x, y, test.y, ans, test.trueAns, test.name+" row maj")
// Test with different alpha
alpha = 4.0
// a itself was not touched by the first call (only its flattened copy
// was), so flattening it again gives a fresh starting matrix.
aFlat = flatten(a)
blasser.Dger(test.m, test.n, alpha, x, test.incX, y, test.incY, aFlat, test.n)
ans = unflatten(aFlat, test.m, test.n)
trueCopy := sliceOfSliceCopy(test.trueAns)
// trueAns holds a + x*yᵀ (alpha = 1), so the update term is recovered as
// trueAns - a and rescaled by the new alpha.
for i := range trueCopy {
for j := range trueCopy[i] {
trueCopy[i][j] = alpha*(trueCopy[i][j]-a[i][j]) + a[i][j]
}
}
dgercomp(t, x, test.x, y, test.y, ans, trueCopy, test.name+" row maj alpha")
}
}
// dgercomp verifies the outcome of a Dger call. It reports an error on t if x
// differs from xCopy or y differs from yCopy (the inputs must not be modified
// by the call), or if any row of ans differs from the corresponding row of
// trueAns beyond tolerance. Only the first mismatching row is reported. name
// identifies the case in error messages.
func dgercomp(t *testing.T, x, xCopy, y, yCopy []float64, ans [][]float64, trueAns [][]float64, name string) {
	if !dSliceEqual(x, xCopy) {
		t.Errorf("case %v: x modified during call to dger", name)
	}
	if !dSliceEqual(y, yCopy) {
		// This check is on y, so the message names y rather than x.
		t.Errorf("case %v: y modified during call to dger", name)
	}
	for i := range ans {
		if !dSliceTolEqual(ans[i], trueAns[i]) {
			t.Errorf("case %v: answer mismatch. Expected %v, Found %v", name, trueAns, ans)
			break
		}
	}
}

83
blas/testblas/dsbmv.go Normal file
View File

@@ -0,0 +1,83 @@
package testblas
import (
"testing"
"github.com/gonum/blas"
)
// Dsbmver is the interface an implementation must satisfy to be exercised by
// DsbmvTest. The fixtures there expect Dsbmv to overwrite y with
// alpha*A*x + beta*y for a symmetric A supplied in band storage.
type Dsbmver interface {
Dsbmv(ul blas.Uplo, n, k int, alpha float64, a []float64, lda int, x []float64, incX int, beta float64, y []float64, incY int)
}
// DsbmvTest runs blasser.Dsbmv on the same symmetric band matrix supplied once
// as its upper and once as its lower triangle, and compares the resulting y
// with the expected answer. Each case is repeated for several combinations of
// positive and negative incX and incY.
func DsbmvTest(t *testing.T, blasser Dsbmver) {
for i, test := range []struct {
ul blas.Uplo
n int
k int
alpha float64
beta float64
a [][]float64 // dense triangle; converted to band storage before the call
x []float64
y []float64
ans []float64 // expected contents of y after the call
}{
{
ul: blas.Upper,
n: 4,
k: 2,
alpha: 2,
beta: 3,
a: [][]float64{
{7, 8, 2, 0},
{0, 8, 2, -3},
{0, 0, 3, 6},
{0, 0, 0, 9},
},
x: []float64{1, 2, 3, 4},
y: []float64{-1, -2, -3, -4},
ans: []float64{55, 30, 69, 84},
},
{
ul: blas.Lower,
n: 4,
k: 2,
alpha: 2,
beta: 3,
a: [][]float64{
{7, 0, 0, 0},
{8, 8, 0, 0},
{2, 2, 3, 0},
{0, -3, 6, 9},
},
x: []float64{1, 2, 3, 4},
y: []float64{-1, -2, -3, -4},
ans: []float64{55, 30, 69, 84},
},
} {
extra := 0
var aFlat []float64
// Only the stored triangle carries a band: k super-diagonals for Upper,
// k sub-diagonals otherwise.
if test.ul == blas.Upper {
aFlat = flattenBanded(test.a, test.k, 0)
} else {
aFlat = flattenBanded(test.a, 0, test.k)
}
// incTest expands x, y and the expected answer to the requested strides,
// calls Dsbmv with lda = k+1 and compares y with the expected answer
// within tolerance.
incTest := func(incX, incY, extra int) {
xnew := makeIncremented(test.x, incX, extra)
ynew := makeIncremented(test.y, incY, extra)
ans := makeIncremented(test.ans, incY, extra)
blasser.Dsbmv(test.ul, test.n, test.k, test.alpha, aFlat, test.k+1, xnew, incX, test.beta, ynew, incY)
if !dSliceTolEqual(ans, ynew) {
t.Errorf("Case %v: Want %v, got %v", i, ans, ynew)
}
}
incTest(1, 1, extra)
incTest(1, 3, extra)
incTest(1, -3, extra)
incTest(2, 3, extra)
incTest(2, -3, extra)
incTest(3, 2, extra)
incTest(-3, 2, extra)
}
}

73
blas/testblas/dspmv.go Normal file
View File

@@ -0,0 +1,73 @@
package testblas
import (
"testing"
"github.com/gonum/blas"
"github.com/gonum/floats"
)
// Dspmver is the interface an implementation must satisfy to be exercised by
// DspmvTest. The fixtures there expect Dspmv to overwrite y with
// alpha*A*x + beta*y for a symmetric A supplied in packed triangular storage.
type Dspmver interface {
Dspmv(ul blas.Uplo, n int, alpha float64, ap []float64, x []float64, incX int, beta float64, y []float64, incY int)
}
// DspmvTest runs blasser.Dspmv on the same symmetric matrix supplied once as
// its upper and once as its lower triangle, and compares the resulting y with
// the expected answer to within 1e-14. Each case is repeated for several
// combinations of positive and negative incX and incY.
func DspmvTest(t *testing.T, blasser Dspmver) {
for i, test := range []struct {
ul blas.Uplo
n int
a [][]float64 // dense triangle; packed with flattenTriangular before the call
x []float64
y []float64
alpha float64
beta float64
ans []float64 // expected contents of y after the call
}{
{
ul: blas.Upper,
n: 3,
a: [][]float64{
{5, 6, 7},
{0, 8, 10},
{0, 0, 13},
},
x: []float64{3, 4, 5},
y: []float64{6, 7, 8},
alpha: 2.1,
beta: -3,
ans: []float64{137.4, 189, 240.6},
},
{
ul: blas.Lower,
n: 3,
a: [][]float64{
{5, 0, 0},
{6, 8, 0},
{7, 10, 13},
},
x: []float64{3, 4, 5},
y: []float64{6, 7, 8},
alpha: 2.1,
beta: -3,
ans: []float64{137.4, 189, 240.6},
},
} {
// incTest expands x, y and the expected answer to the requested strides
// (plus extra trailing filler values), calls Dspmv and compares y with
// the expected answer.
incTest := func(incX, incY, extra int) {
x := makeIncremented(test.x, incX, extra)
y := makeIncremented(test.y, incY, extra)
aFlat := flattenTriangular(test.a, test.ul)
ans := makeIncremented(test.ans, incY, extra)
blasser.Dspmv(test.ul, test.n, test.alpha, aFlat, x, incX, test.beta, y, incY)
if !floats.EqualApprox(ans, y, 1e-14) {
t.Errorf("Case %v, incX=%v, incY=%v: Want %v, got %v.", i, incX, incY, ans, y)
}
}
incTest(1, 1, 0)
incTest(2, 3, 0)
incTest(3, 2, 0)
incTest(-3, 2, 0)
incTest(-2, 4, 0)
incTest(2, -1, 0)
incTest(-3, -4, 3)
}
}

71
blas/testblas/dspr.go Normal file
View File

@@ -0,0 +1,71 @@
package testblas
import (
"testing"
"github.com/gonum/blas"
)
// Dsprer is the interface an implementation must satisfy to be exercised by
// DsprTest. The fixtures there expect Dspr to add alpha*x*xᵀ to the selected
// triangle of a, which is supplied in packed triangular storage.
type Dsprer interface {
Dspr(ul blas.Uplo, n int, alpha float64, x []float64, incX int, a []float64)
}
// DsprTest runs blasser.Dspr on one upper-triangular and one lower-triangular
// packed matrix and compares the updated packed data with the expected answer.
// Each case is repeated for several positive and negative values of incX.
func DsprTest(t *testing.T, blasser Dsprer) {
for i, test := range []struct {
ul blas.Uplo
n int
// a and ans are written densely; only the triangle selected by ul is
// packed and compared, so entries in the other triangle are ignored.
a [][]float64
x []float64
alpha float64
ans [][]float64
}{
{
ul: blas.Upper,
n: 4,
a: [][]float64{
{10, 2, 0, 1},
{0, 1, 2, 3},
{0, 0, 9, 15},
{0, 0, 0, -6},
},
x: []float64{1, 2, 0, 5},
alpha: 8,
ans: [][]float64{
{18, 18, 0, 41},
{0, 33, 2, 83},
{0, 0, 9, 15},
{0, 0, 0, 194},
},
},
{
ul: blas.Lower,
n: 3,
a: [][]float64{
{10, 2, 0},
{4, 1, 2},
{2, 7, 9},
},
x: []float64{3, 0, 5},
alpha: 8,
ans: [][]float64{
{82, 2, 0},
{4, 1, 2},
{122, 7, 209},
},
},
} {
// incTest expands x to the requested stride (plus extra trailing filler
// values), packs a fresh copy of a, calls Dspr and compares the packed
// result with the packed expected answer within tolerance.
incTest := func(incX, extra int) {
xnew := makeIncremented(test.x, incX, extra)
aFlat := flattenTriangular(test.a, test.ul)
ans := flattenTriangular(test.ans, test.ul)
blasser.Dspr(test.ul, test.n, test.alpha, xnew, incX, aFlat)
if !dSliceTolEqual(aFlat, ans) {
t.Errorf("Case %v, idx %v: Want %v, got %v.", i, incX, ans, aFlat)
}
}
incTest(1, 3)
incTest(1, 0)
incTest(3, 2)
incTest(-2, 2)
}
}

76
blas/testblas/dspr2.go Normal file
View File

@@ -0,0 +1,76 @@
package testblas
import (
"testing"
"github.com/gonum/blas"
"github.com/gonum/floats"
)
// Dspr2er is the interface an implementation must satisfy to be exercised by
// Dspr2Test. The fixtures there expect Dspr2 to add alpha*(x*yᵀ + y*xᵀ) to the
// selected triangle of a, which is supplied in packed triangular storage.
type Dspr2er interface {
Dspr2(ul blas.Uplo, n int, alpha float64, x []float64, incX int, y []float64, incY int, a []float64)
}
// Dspr2Test runs blasser.Dspr2 on the same symmetric matrix supplied once as
// its upper and once as its lower triangle, and compares the updated packed
// data with the expected answer to within 1e-14. Each case is repeated for
// several combinations of positive and negative incX and incY.
func Dspr2Test(t *testing.T, blasser Dspr2er) {
for i, test := range []struct {
n int
a [][]float64 // dense triangle; packed with flattenTriangular before the call
ul blas.Uplo
x []float64
y []float64
alpha float64
ans [][]float64 // expected dense triangle after the call
}{
{
n: 3,
a: [][]float64{
{7, 2, 4},
{0, 3, 5},
{0, 0, 6},
},
x: []float64{2, 3, 4},
y: []float64{5, 6, 7},
alpha: 2,
ul: blas.Upper,
ans: [][]float64{
{47, 56, 72},
{0, 75, 95},
{0, 0, 118},
},
},
{
n: 3,
a: [][]float64{
{7, 0, 0},
{2, 3, 0},
{4, 5, 6},
},
x: []float64{2, 3, 4},
y: []float64{5, 6, 7},
alpha: 2,
ul: blas.Lower,
ans: [][]float64{
{47, 0, 0},
{56, 75, 0},
{72, 95, 118},
},
},
} {
// incTest packs a fresh copy of a, expands x and y to the requested
// strides, calls Dspr2 and compares the packed result with the packed
// expected answer.
incTest := func(incX, incY, extra int) {
aFlat := flattenTriangular(test.a, test.ul)
x := makeIncremented(test.x, incX, extra)
y := makeIncremented(test.y, incY, extra)
blasser.Dspr2(test.ul, test.n, test.alpha, x, incX, y, incY, aFlat)
ansFlat := flattenTriangular(test.ans, test.ul)
if !floats.EqualApprox(aFlat, ansFlat, 1e-14) {
t.Errorf("Case %v, incX = %v, incY = %v. Want %v, got %v.", i, incX, incY, ansFlat, aFlat)
}
}
incTest(1, 1, 0)
incTest(-2, 1, 0)
incTest(-2, 3, 0)
incTest(2, -3, 0)
incTest(3, -2, 0)
incTest(-3, -4, 0)
}
}

277
blas/testblas/dsymm.go Normal file
View File

@@ -0,0 +1,277 @@
package testblas
import (
"testing"
"github.com/gonum/blas"
"github.com/gonum/floats"
)
// Dsymmer is the interface an implementation must satisfy to be exercised by
// DsymmTest.
// NOTE(review): the visible fixtures suggest Dsymm overwrites c with
// alpha*A*B + beta*C (or alpha*B*A + beta*C for blas.Right) with A symmetric —
// confirm against the BLAS reference.
type Dsymmer interface {
Dsymm(s blas.Side, ul blas.Uplo, m, n int, alpha float64, a []float64, lda int, b []float64, ldb int, beta float64, c []float64, ldc int)
}
// DsymmTest checks blasser.Dsymm against hand-computed results of
// C = alpha*A*B + beta*C (side Left) and C = alpha*B*A + beta*C (side
// Right), where only the ul triangle of the symmetric matrix A is supplied.
//
// Note that cases 4 and 5 (Right, Upper) store the off-diagonal entries of a
// below the diagonal; their expected answers correspond to the diagonal of a
// alone, so these cases also check that the lower triangle is not read.
func DsymmTest(t *testing.T, blasser Dsymmer) {
	cases := []struct {
		m     int
		n     int
		side  blas.Side
		ul    blas.Uplo
		a     [][]float64
		b     [][]float64
		c     [][]float64
		alpha float64
		beta  float64
		ans   [][]float64
	}{
		{
			side: blas.Left,
			ul:   blas.Upper,
			m:    3,
			n:    4,
			a: [][]float64{
				{2, 3, 4},
				{0, 6, 7},
				{0, 0, 10},
			},
			b: [][]float64{
				{2, 3, 4, 8},
				{5, 6, 7, 15},
				{8, 9, 10, 20},
			},
			c: [][]float64{
				{8, 12, 2, 1},
				{9, 12, 9, 9},
				{12, 1, -1, 5},
			},
			alpha: 2,
			beta:  3,
			ans: [][]float64{
				{126, 156, 144, 285},
				{211, 252, 275, 535},
				{282, 291, 327, 689},
			},
		},
		{
			side: blas.Left,
			ul:   blas.Upper,
			m:    4,
			n:    3,
			a: [][]float64{
				{2, 3, 4, 8},
				{0, 6, 7, 9},
				{0, 0, 10, 10},
				{0, 0, 0, 11},
			},
			b: [][]float64{
				{2, 3, 4},
				{5, 6, 7},
				{8, 9, 10},
				{2, 1, 1},
			},
			c: [][]float64{
				{8, 12, 2},
				{9, 12, 9},
				{12, 1, -1},
				{1, 9, 5},
			},
			alpha: 2,
			beta:  3,
			ans: [][]float64{
				{158, 172, 160},
				{247, 270, 293},
				{322, 311, 347},
				{329, 385, 427},
			},
		},
		{
			side: blas.Left,
			ul:   blas.Lower,
			m:    3,
			n:    4,
			a: [][]float64{
				{2, 0, 0},
				{3, 6, 0},
				{4, 7, 10},
			},
			b: [][]float64{
				{2, 3, 4, 8},
				{5, 6, 7, 15},
				{8, 9, 10, 20},
			},
			c: [][]float64{
				{8, 12, 2, 1},
				{9, 12, 9, 9},
				{12, 1, -1, 5},
			},
			alpha: 2,
			beta:  3,
			ans: [][]float64{
				{126, 156, 144, 285},
				{211, 252, 275, 535},
				{282, 291, 327, 689},
			},
		},
		{
			side: blas.Left,
			ul:   blas.Lower,
			m:    4,
			n:    3,
			a: [][]float64{
				{2, 0, 0, 0},
				{3, 6, 0, 0},
				{4, 7, 10, 0},
				{8, 9, 10, 11},
			},
			b: [][]float64{
				{2, 3, 4},
				{5, 6, 7},
				{8, 9, 10},
				{2, 1, 1},
			},
			c: [][]float64{
				{8, 12, 2},
				{9, 12, 9},
				{12, 1, -1},
				{1, 9, 5},
			},
			alpha: 2,
			beta:  3,
			ans: [][]float64{
				{158, 172, 160},
				{247, 270, 293},
				{322, 311, 347},
				{329, 385, 427},
			},
		},
		{
			side: blas.Right,
			ul:   blas.Upper,
			m:    3,
			n:    4,
			a: [][]float64{
				{2, 0, 0, 0},
				{3, 6, 0, 0},
				{4, 7, 10, 0},
				{3, 4, 5, 6},
			},
			b: [][]float64{
				{2, 3, 4, 9},
				{5, 6, 7, -3},
				{8, 9, 10, -2},
			},
			c: [][]float64{
				{8, 12, 2, 10},
				{9, 12, 9, 10},
				{12, 1, -1, 10},
			},
			alpha: 2,
			beta:  3,
			ans: [][]float64{
				{32, 72, 86, 138},
				{47, 108, 167, -6},
				{68, 111, 197, 6},
			},
		},
		{
			side: blas.Right,
			ul:   blas.Upper,
			m:    4,
			n:    3,
			a: [][]float64{
				{2, 0, 0},
				{3, 6, 0},
				{4, 7, 10},
			},
			b: [][]float64{
				{2, 3, 4},
				{5, 6, 7},
				{8, 9, 10},
				{2, 1, 1},
			},
			c: [][]float64{
				{8, 12, 2},
				{9, 12, 9},
				{12, 1, -1},
				{1, 9, 5},
			},
			alpha: 2,
			beta:  3,
			ans: [][]float64{
				{32, 72, 86},
				{47, 108, 167},
				{68, 111, 197},
				{11, 39, 35},
			},
		},
		{
			side: blas.Right,
			ul:   blas.Lower,
			m:    3,
			n:    4,
			a: [][]float64{
				{2, 0, 0, 0},
				{3, 6, 0, 0},
				{4, 7, 10, 0},
				{3, 4, 5, 6},
			},
			b: [][]float64{
				{2, 3, 4, 2},
				{5, 6, 7, 1},
				{8, 9, 10, 1},
			},
			c: [][]float64{
				{8, 12, 2, 1},
				{9, 12, 9, 9},
				{12, 1, -1, 5},
			},
			alpha: 2,
			beta:  3,
			ans: [][]float64{
				{94, 156, 164, 103},
				{145, 244, 301, 187},
				{208, 307, 397, 247},
			},
		},
		{
			side: blas.Right,
			ul:   blas.Lower,
			m:    4,
			n:    3,
			a: [][]float64{
				{2, 0, 0},
				{3, 6, 0},
				{4, 7, 10},
			},
			b: [][]float64{
				{2, 3, 4},
				{5, 6, 7},
				{8, 9, 10},
				{2, 1, 1},
			},
			c: [][]float64{
				{8, 12, 2},
				{9, 12, 9},
				{12, 1, -1},
				{1, 9, 5},
			},
			alpha: 2,
			beta:  3,
			ans: [][]float64{
				{82, 140, 144},
				{139, 236, 291},
				{202, 299, 387},
				{25, 65, 65},
			},
		},
	}

	for i, tc := range cases {
		a := flatten(tc.a)
		b := flatten(tc.b)
		got := flatten(tc.c)
		want := flatten(tc.ans)
		// a is square (m×m or n×n depending on side); b and c are m×n.
		lda := len(tc.a[0])
		blasser.Dsymm(tc.side, tc.ul, tc.m, tc.n, tc.alpha, a, lda, b, tc.n, tc.beta, got, tc.n)
		if floats.EqualApprox(got, want, 1e-14) {
			continue
		}
		t.Errorf("Case %v: Want %v, got %v.", i, want, got)
	}
}

73
blas/testblas/dsymv.go Normal file
View File

@@ -0,0 +1,73 @@
package testblas
import (
"testing"
"github.com/gonum/blas"
"github.com/gonum/floats"
)
// Dsymver is implemented by types that provide the Dsymv routine exercised
// by DsymvTest.
type Dsymver interface {
Dsymv(ul blas.Uplo, n int, alpha float64, a []float64, lda int, x []float64, incX int, beta float64, y []float64, incY int)
}
// DsymvTest checks blasser.Dsymv against hand-computed results of
// y = alpha*A*x + beta*y for a symmetric A given by its upper or lower
// triangle. Every case is repeated for several x and y increments.
func DsymvTest(t *testing.T, blasser Dsymver) {
	cases := []struct {
		ul    blas.Uplo
		n     int
		a     [][]float64
		x     []float64
		y     []float64
		alpha float64
		beta  float64
		ans   []float64
	}{
		{
			ul: blas.Upper,
			n:  3,
			a: [][]float64{
				{5, 6, 7},
				{0, 8, 10},
				{0, 0, 13},
			},
			x:     []float64{3, 4, 5},
			y:     []float64{6, 7, 8},
			alpha: 2.1,
			beta:  -3,
			ans:   []float64{137.4, 189, 240.6},
		},
		{
			ul: blas.Lower,
			n:  3,
			a: [][]float64{
				{5, 0, 0},
				{6, 8, 0},
				{7, 10, 13},
			},
			x:     []float64{3, 4, 5},
			y:     []float64{6, 7, 8},
			alpha: 2.1,
			beta:  -3,
			ans:   []float64{137.4, 189, 240.6},
		},
	}

	// Each entry is {incX, incY, extra}.
	runs := [][3]int{
		{1, 1, 0},
		{2, 3, 0},
		{3, 2, 0},
		{-3, 2, 0},
		{-2, 4, 0},
		{2, -1, 0},
		{-3, -4, 3},
	}

	for i, tc := range cases {
		check := func(incX, incY, extra int) {
			xs := makeIncremented(tc.x, incX, extra)
			got := makeIncremented(tc.y, incY, extra)
			a := flatten(tc.a)
			// The expected vector is laid out with the same stride as y.
			want := makeIncremented(tc.ans, incY, extra)
			blasser.Dsymv(tc.ul, tc.n, tc.alpha, a, tc.n, xs, incX, tc.beta, got, incY)
			if floats.EqualApprox(want, got, 1e-14) {
				return
			}
			t.Errorf("Case %v, incX=%v, incY=%v: Want %v, got %v.", i, incX, incY, want, got)
		}
		for _, r := range runs {
			check(r[0], r[1], r[2])
		}
	}
}

72
blas/testblas/dsyr.go Normal file
View File

@@ -0,0 +1,72 @@
package testblas
import (
"testing"
"github.com/gonum/blas"
)
// Dsyrer is implemented by types that provide the Dsyr routine exercised by
// DsyrTest.
type Dsyrer interface {
Dsyr(ul blas.Uplo, n int, alpha float64, x []float64, incX int, a []float64, lda int)
}
// DsyrTest checks blasser.Dsyr against hand-computed results of the
// symmetric rank-1 update A += alpha*x*xᵀ applied to the ul triangle of A.
// The Lower case carries non-zero values above the diagonal which the
// expected answer leaves unchanged.
func DsyrTest(t *testing.T, blasser Dsyrer) {
	cases := []struct {
		ul    blas.Uplo
		n     int
		a     [][]float64
		x     []float64
		alpha float64
		ans   [][]float64
	}{
		{
			ul: blas.Upper,
			n:  4,
			a: [][]float64{
				{10, 2, 0, 1},
				{0, 1, 2, 3},
				{0, 0, 9, 15},
				{0, 0, 0, -6},
			},
			x:     []float64{1, 2, 0, 5},
			alpha: 8,
			ans: [][]float64{
				{18, 18, 0, 41},
				{0, 33, 2, 83},
				{0, 0, 9, 15},
				{0, 0, 0, 194},
			},
		},
		{
			ul: blas.Lower,
			n:  3,
			a: [][]float64{
				{10, 2, 0},
				{4, 1, 2},
				{2, 7, 9},
			},
			x:     []float64{3, 0, 5},
			alpha: 8,
			ans: [][]float64{
				{82, 2, 0},
				{4, 1, 2},
				{122, 7, 209},
			},
		},
	}

	// Each entry is {incX, extra}.
	runs := [][2]int{{1, 3}, {1, 0}, {3, 2}, {-2, 2}}

	for i, tc := range cases {
		check := func(incX, extra int) {
			xs := makeIncremented(tc.x, incX, extra)
			got := flatten(tc.a)
			want := flatten(tc.ans)
			blasser.Dsyr(tc.ul, tc.n, tc.alpha, xs, incX, got, tc.n)
			if dSliceTolEqual(got, want) {
				return
			}
			t.Errorf("Case %v, idx %v: Want %v, got %v.", i, incX, want, got)
		}
		for _, r := range runs {
			check(r[0], r[1])
		}
	}
}

76
blas/testblas/dsyr2.go Normal file
View File

@@ -0,0 +1,76 @@
package testblas
import (
"testing"
"github.com/gonum/blas"
"github.com/gonum/floats"
)
// Dsyr2er is implemented by types that provide the Dsyr2 routine exercised
// by Dsyr2Test.
type Dsyr2er interface {
Dsyr2(ul blas.Uplo, n int, alpha float64, x []float64, incX int, y []float64, incY int, a []float64, lda int)
}
// Dsyr2Test checks blasser.Dsyr2 against hand-computed results of the
// symmetric rank-2 update A += alpha*(x*yᵀ + y*xᵀ) on a dense matrix with
// leading dimension n. Every case is repeated for several positive and
// negative increments of x and y.
func Dsyr2Test(t *testing.T, blasser Dsyr2er) {
	cases := []struct {
		n     int
		a     [][]float64
		ul    blas.Uplo
		x     []float64
		y     []float64
		alpha float64
		ans   [][]float64
	}{
		{
			n: 3,
			a: [][]float64{
				{7, 2, 4},
				{0, 3, 5},
				{0, 0, 6},
			},
			x:     []float64{2, 3, 4},
			y:     []float64{5, 6, 7},
			alpha: 2,
			ul:    blas.Upper,
			ans: [][]float64{
				{47, 56, 72},
				{0, 75, 95},
				{0, 0, 118},
			},
		},
		{
			n: 3,
			a: [][]float64{
				{7, 0, 0},
				{2, 3, 0},
				{4, 5, 6},
			},
			x:     []float64{2, 3, 4},
			y:     []float64{5, 6, 7},
			alpha: 2,
			ul:    blas.Lower,
			ans: [][]float64{
				{47, 0, 0},
				{56, 75, 0},
				{72, 95, 118},
			},
		},
	}

	// Each entry is an {incX, incY} pair; no extra trailing elements are used.
	increments := [][2]int{{1, 1}, {-2, 1}, {-2, 3}, {2, -3}, {3, -2}, {-3, -4}}

	for i, tc := range cases {
		check := func(incX, incY, extra int) {
			// A fresh flat copy is needed per run because Dsyr2 updates it.
			got := flatten(tc.a)
			xs := makeIncremented(tc.x, incX, extra)
			ys := makeIncremented(tc.y, incY, extra)
			blasser.Dsyr2(tc.ul, tc.n, tc.alpha, xs, incX, ys, incY, got, tc.n)
			want := flatten(tc.ans)
			if floats.EqualApprox(got, want, 1e-14) {
				return
			}
			t.Errorf("Case %v, incX = %v, incY = %v. Want %v, got %v.", i, incX, incY, want, got)
		}
		for _, inc := range increments {
			check(inc[0], inc[1], 0)
		}
	}
}

201
blas/testblas/dsyr2k.go Normal file
View File

@@ -0,0 +1,201 @@
package testblas
import (
"testing"
"github.com/gonum/blas"
"github.com/gonum/floats"
)
// Dsyr2ker is implemented by types that provide the Dsyr2k routine exercised
// by Dsyr2kTest.
type Dsyr2ker interface {
Dsyr2k(ul blas.Uplo, tA blas.Transpose, n, k int, alpha float64, a []float64, lda int, b []float64, ldb int, beta float64, c []float64, ldc int)
}
// Dsyr2kTest checks blasser.Dsyr2k against hand-computed results of the
// symmetric rank-2k update C = alpha*(A*Bᵀ + B*Aᵀ) + beta*C on the ul
// triangle of C. The Trans cases supply the same A and B already transposed
// (k×n) and expect the same answers as their NoTrans counterparts.
func Dsyr2kTest(t *testing.T, blasser Dsyr2ker) {
	cases := []struct {
		ul    blas.Uplo
		tA    blas.Transpose
		n     int
		k     int
		alpha float64
		a     [][]float64
		b     [][]float64
		c     [][]float64
		beta  float64
		ans   [][]float64
	}{
		{
			ul:    blas.Upper,
			tA:    blas.NoTrans,
			n:     3,
			k:     2,
			alpha: 0,
			a: [][]float64{
				{1, 2},
				{3, 4},
				{5, 6},
			},
			b: [][]float64{
				{7, 8},
				{9, 10},
				{11, 12},
			},
			c: [][]float64{
				{1, 2, 3},
				{0, 5, 6},
				{0, 0, 9},
			},
			beta: 2,
			ans: [][]float64{
				{2, 4, 6},
				{0, 10, 12},
				{0, 0, 18},
			},
		},
		{
			ul:    blas.Lower,
			tA:    blas.NoTrans,
			n:     3,
			k:     2,
			alpha: 0,
			a: [][]float64{
				{1, 2},
				{3, 4},
				{5, 6},
			},
			b: [][]float64{
				{7, 8},
				{9, 10},
				{11, 12},
			},
			c: [][]float64{
				{1, 0, 0},
				{2, 3, 0},
				{4, 5, 6},
			},
			beta: 2,
			ans: [][]float64{
				{2, 0, 0},
				{4, 6, 0},
				{8, 10, 12},
			},
		},
		{
			ul:    blas.Upper,
			tA:    blas.NoTrans,
			n:     3,
			k:     2,
			alpha: 3,
			a: [][]float64{
				{1, 2},
				{3, 4},
				{5, 6},
			},
			b: [][]float64{
				{7, 8},
				{9, 10},
				{11, 12},
			},
			c: [][]float64{
				{1, 2, 3},
				{0, 4, 5},
				{0, 0, 6},
			},
			beta: 2,
			ans: [][]float64{
				{140, 250, 360},
				{0, 410, 568},
				{0, 0, 774},
			},
		},
		{
			ul:    blas.Lower,
			tA:    blas.NoTrans,
			n:     3,
			k:     2,
			alpha: 3,
			a: [][]float64{
				{1, 2},
				{3, 4},
				{5, 6},
			},
			b: [][]float64{
				{7, 8},
				{9, 10},
				{11, 12},
			},
			c: [][]float64{
				{1, 0, 0},
				{2, 4, 0},
				{3, 5, 6},
			},
			beta: 2,
			ans: [][]float64{
				{140, 0, 0},
				{250, 410, 0},
				{360, 568, 774},
			},
		},
		{
			ul:    blas.Upper,
			tA:    blas.Trans,
			n:     3,
			k:     2,
			alpha: 3,
			a: [][]float64{
				{1, 3, 5},
				{2, 4, 6},
			},
			b: [][]float64{
				{7, 9, 11},
				{8, 10, 12},
			},
			c: [][]float64{
				{1, 2, 3},
				{0, 4, 5},
				{0, 0, 6},
			},
			beta: 2,
			ans: [][]float64{
				{140, 250, 360},
				{0, 410, 568},
				{0, 0, 774},
			},
		},
		{
			ul:    blas.Lower,
			tA:    blas.Trans,
			n:     3,
			k:     2,
			alpha: 3,
			a: [][]float64{
				{1, 3, 5},
				{2, 4, 6},
			},
			b: [][]float64{
				{7, 9, 11},
				{8, 10, 12},
			},
			c: [][]float64{
				{1, 0, 0},
				{2, 4, 0},
				{3, 5, 6},
			},
			beta: 2,
			ans: [][]float64{
				{140, 0, 0},
				{250, 410, 0},
				{360, 568, 774},
			},
		},
	}

	for i, tc := range cases {
		a := flatten(tc.a)
		b := flatten(tc.b)
		got := flatten(tc.c)
		want := flatten(tc.ans)
		// Leading dimensions are the row lengths of the literal matrices.
		lda, ldb, ldc := len(tc.a[0]), len(tc.b[0]), len(tc.c[0])
		blasser.Dsyr2k(tc.ul, tc.tA, tc.n, tc.k, tc.alpha, a, lda, b, ldb, tc.beta, got, ldc)
		if floats.EqualApprox(want, got, 1e-14) {
			continue
		}
		t.Errorf("Case %v. Want %v, got %v.", i, want, got)
	}
}

171
blas/testblas/dsyrk.go Normal file
View File

@@ -0,0 +1,171 @@
package testblas
import (
"testing"
"github.com/gonum/blas"
"github.com/gonum/floats"
)
// Dsyker is implemented by types that provide the Dsyrk routine exercised by
// DsyrkTest.
//
// NOTE(review): the name drops the "r" of Dsyrk (cf. Dsyr2ker); it is
// exported, so renaming it would break existing callers.
type Dsyker interface {
Dsyrk(ul blas.Uplo, tA blas.Transpose, n, k int, alpha float64, a []float64, lda int, beta float64, c []float64, ldc int)
}
// DsyrkTest checks blasser.Dsyrk against hand-computed results of the
// symmetric rank-k update C = alpha*A*Aᵀ + beta*C on the ul triangle of C.
// The Trans cases supply the same A already transposed (k×n) and expect the
// same answers as their NoTrans counterparts.
func DsyrkTest(t *testing.T, blasser Dsyker) {
	cases := []struct {
		ul    blas.Uplo
		tA    blas.Transpose
		n     int
		k     int
		alpha float64
		a     [][]float64
		c     [][]float64
		beta  float64
		ans   [][]float64
	}{
		{
			ul:    blas.Upper,
			tA:    blas.NoTrans,
			n:     3,
			k:     2,
			alpha: 0,
			a: [][]float64{
				{1, 2},
				{3, 4},
				{5, 6},
			},
			c: [][]float64{
				{1, 2, 3},
				{0, 5, 6},
				{0, 0, 9},
			},
			beta: 2,
			ans: [][]float64{
				{2, 4, 6},
				{0, 10, 12},
				{0, 0, 18},
			},
		},
		{
			ul:    blas.Lower,
			tA:    blas.NoTrans,
			n:     3,
			k:     2,
			alpha: 0,
			a: [][]float64{
				{1, 2},
				{3, 4},
				{5, 6},
			},
			c: [][]float64{
				{1, 0, 0},
				{2, 3, 0},
				{4, 5, 6},
			},
			beta: 2,
			ans: [][]float64{
				{2, 0, 0},
				{4, 6, 0},
				{8, 10, 12},
			},
		},
		{
			ul:    blas.Upper,
			tA:    blas.NoTrans,
			n:     3,
			k:     2,
			alpha: 3,
			a: [][]float64{
				{1, 2},
				{3, 4},
				{5, 6},
			},
			c: [][]float64{
				{1, 2, 3},
				{0, 4, 5},
				{0, 0, 6},
			},
			beta: 2,
			ans: [][]float64{
				{17, 37, 57},
				{0, 83, 127},
				{0, 0, 195},
			},
		},
		{
			ul:    blas.Lower,
			tA:    blas.NoTrans,
			n:     3,
			k:     2,
			alpha: 3,
			a: [][]float64{
				{1, 2},
				{3, 4},
				{5, 6},
			},
			c: [][]float64{
				{1, 0, 0},
				{2, 4, 0},
				{3, 5, 6},
			},
			beta: 2,
			ans: [][]float64{
				{17, 0, 0},
				{37, 83, 0},
				{57, 127, 195},
			},
		},
		{
			ul:    blas.Upper,
			tA:    blas.Trans,
			n:     3,
			k:     2,
			alpha: 3,
			a: [][]float64{
				{1, 3, 5},
				{2, 4, 6},
			},
			c: [][]float64{
				{1, 2, 3},
				{0, 4, 5},
				{0, 0, 6},
			},
			beta: 2,
			ans: [][]float64{
				{17, 37, 57},
				{0, 83, 127},
				{0, 0, 195},
			},
		},
		{
			ul:    blas.Lower,
			tA:    blas.Trans,
			n:     3,
			k:     2,
			alpha: 3,
			a: [][]float64{
				{1, 3, 5},
				{2, 4, 6},
			},
			c: [][]float64{
				{1, 0, 0},
				{2, 4, 0},
				{3, 5, 6},
			},
			beta: 2,
			ans: [][]float64{
				{17, 0, 0},
				{37, 83, 0},
				{57, 127, 195},
			},
		},
	}

	for i, tc := range cases {
		a := flatten(tc.a)
		got := flatten(tc.c)
		want := flatten(tc.ans)
		// Leading dimensions are the row lengths of the literal matrices.
		lda, ldc := len(tc.a[0]), len(tc.c[0])
		blasser.Dsyrk(tc.ul, tc.tA, tc.n, tc.k, tc.alpha, a, lda, tc.beta, got, ldc)
		if floats.EqualApprox(want, got, 1e-14) {
			continue
		}
		t.Errorf("Case %v. Want %v, got %v.", i, want, got)
	}
}

123
blas/testblas/dtbmv.go Normal file
View File

@@ -0,0 +1,123 @@
package testblas
import (
"testing"
"github.com/gonum/blas"
)
// Dtbmver is implemented by types that provide the Dtbmv routine exercised
// by DtbmvTest.
type Dtbmver interface {
Dtbmv(ul blas.Uplo, tA blas.Transpose, d blas.Diag, n, k int, a []float64, lda int, x []float64, incX int)
}
// DtbmvTest checks blasser.Dtbmv against hand-computed triangular banded
// matrix-vector products x = A*x (NoTrans) and x = Aᵀ*x (Trans). The dense
// matrix of each case is converted with flattenBanded and passed with a
// leading dimension of k+1; each case is run for increments 1, 3 and -2.
func DtbmvTest(t *testing.T, blasser Dtbmver) {
	cases := []struct {
		ul  blas.Uplo
		tA  blas.Transpose
		d   blas.Diag
		n   int
		k   int
		a   [][]float64
		x   []float64
		ans []float64
	}{
		{
			ul: blas.Upper,
			tA: blas.NoTrans,
			d:  blas.Unit,
			n:  3,
			k:  1,
			a: [][]float64{
				{1, 2, 0},
				{0, 1, 4},
				{0, 0, 1},
			},
			x:   []float64{2, 3, 4},
			ans: []float64{8, 19, 4},
		},
		{
			ul: blas.Upper,
			tA: blas.NoTrans,
			d:  blas.NonUnit,
			n:  5,
			k:  1,
			a: [][]float64{
				{1, 3, 0, 0, 0},
				{0, 6, 7, 0, 0},
				{0, 0, 2, 1, 0},
				{0, 0, 0, 12, 3},
				{0, 0, 0, 0, -1},
			},
			x:   []float64{1, 2, 3, 4, 5},
			ans: []float64{7, 33, 10, 63, -5},
		},
		{
			ul: blas.Lower,
			tA: blas.NoTrans,
			d:  blas.NonUnit,
			n:  5,
			k:  1,
			a: [][]float64{
				{7, 0, 0, 0, 0},
				{3, 6, 0, 0, 0},
				{0, 7, 2, 0, 0},
				{0, 0, 1, 12, 0},
				{0, 0, 0, 3, -1},
			},
			x:   []float64{1, 2, 3, 4, 5},
			ans: []float64{7, 15, 20, 51, 7},
		},
		{
			ul: blas.Upper,
			tA: blas.Trans,
			d:  blas.NonUnit,
			n:  5,
			k:  2,
			a: [][]float64{
				{7, 3, 9, 0, 0},
				{0, 6, 7, 10, 0},
				{0, 0, 2, 1, 11},
				{0, 0, 0, 12, 3},
				{0, 0, 0, 0, -1},
			},
			x:   []float64{1, 2, 3, 4, 5},
			ans: []float64{7, 15, 29, 71, 40},
		},
		{
			ul: blas.Lower,
			tA: blas.Trans,
			d:  blas.NonUnit,
			n:  5,
			k:  2,
			a: [][]float64{
				{7, 0, 0, 0, 0},
				{3, 6, 0, 0, 0},
				{9, 7, 2, 0, 0},
				{0, 10, 1, 12, 0},
				{0, 0, 11, 3, -1},
			},
			x:   []float64{1, 2, 3, 4, 5},
			ans: []float64{40, 73, 65, 63, -5},
		},
	}

	for i, tc := range cases {
		// The band is built once per case and shared by every increment run.
		var band []float64
		switch tc.ul {
		case blas.Upper:
			band = flattenBanded(tc.a, tc.k, 0)
		default:
			band = flattenBanded(tc.a, 0, tc.k)
		}
		check := func(incX, extra int) {
			got := makeIncremented(tc.x, incX, extra)
			want := makeIncremented(tc.ans, incX, extra)
			blasser.Dtbmv(tc.ul, tc.tA, tc.d, tc.n, tc.k, band, tc.k+1, got, incX)
			if dSliceTolEqual(want, got) {
				return
			}
			t.Errorf("Case %v, Inc %v: Want %v, got %v", i, incX, want, got)
		}
		for _, incX := range []int{1, 3, -2} {
			check(incX, 0)
		}
	}
}

256
blas/testblas/dtbsv.go Normal file
View File

@@ -0,0 +1,256 @@
package testblas
import (
"testing"
"github.com/gonum/blas"
)
// Dtbsver is implemented by types that provide the Dtbsv routine exercised
// by DtbsvTest. Dtrsv is also required, but DtbsvTest only refers to it from
// its currently commented-out comparison against the dense solver.
type Dtbsver interface {
Dtbsv(ul blas.Uplo, tA blas.Transpose, d blas.Diag, n, k int, a []float64, lda int, x []float64, incX int)
Dtrsv(ul blas.Uplo, tA blas.Transpose, d blas.Diag, n int, a []float64, lda int, x []float64, incX int)
}
// DtbsvTest checks blasser.Dtbsv against precomputed solutions of the
// triangular banded systems A*x = b (NoTrans) and Aᵀ*x = b (Trans). The
// right-hand side is passed in x and overwritten with the solution. Cases
// with incX = 2 interleave sentinel values that must be left untouched.
//
// The dense matrix of each case is converted with flattenBanded and passed
// with a leading dimension of k+1.
func DtbsvTest(t *testing.T, blasser Dtbsver) {
	for i, test := range []struct {
		ul   blas.Uplo
		tA   blas.Transpose
		d    blas.Diag
		n, k int
		a    [][]float64
		x    []float64
		incX int
		ans  []float64
	}{
		{
			ul: blas.Upper,
			tA: blas.NoTrans,
			d:  blas.NonUnit,
			n:  5,
			k:  1,
			a: [][]float64{
				{1, 3, 0, 0, 0},
				{0, 6, 7, 0, 0},
				{0, 0, 2, 1, 0},
				{0, 0, 0, 12, 3},
				{0, 0, 0, 0, -1},
			},
			x:    []float64{1, 2, 3, 4, 5},
			incX: 1,
			ans:  []float64{2.479166666666667, -0.493055555555556, 0.708333333333333, 1.583333333333333, -5.000000000000000},
		},
		{
			ul: blas.Upper,
			tA: blas.NoTrans,
			d:  blas.NonUnit,
			n:  5,
			k:  2,
			a: [][]float64{
				{1, 3, 5, 0, 0},
				{0, 6, 7, 5, 0},
				{0, 0, 2, 1, 5},
				{0, 0, 0, 12, 3},
				{0, 0, 0, 0, -1},
			},
			x:    []float64{1, 2, 3, 4, 5},
			incX: 1,
			ans:  []float64{-15.854166666666664, -16.395833333333336, 13.208333333333334, 1.583333333333333, -5.000000000000000},
		},
		{
			ul: blas.Upper,
			tA: blas.NoTrans,
			d:  blas.NonUnit,
			n:  5,
			k:  1,
			a: [][]float64{
				{1, 3, 0, 0, 0},
				{0, 6, 7, 0, 0},
				{0, 0, 2, 1, 0},
				{0, 0, 0, 12, 3},
				{0, 0, 0, 0, -1},
			},
			x:    []float64{1, -101, 2, -201, 3, -301, 4, -401, 5, -501, -601, -701},
			incX: 2,
			ans:  []float64{2.479166666666667, -101, -0.493055555555556, -201, 0.708333333333333, -301, 1.583333333333333, -401, -5.000000000000000, -501, -601, -701},
		},
		{
			ul: blas.Upper,
			tA: blas.NoTrans,
			d:  blas.NonUnit,
			n:  5,
			k:  2,
			a: [][]float64{
				{1, 3, 5, 0, 0},
				{0, 6, 7, 5, 0},
				{0, 0, 2, 1, 5},
				{0, 0, 0, 12, 3},
				{0, 0, 0, 0, -1},
			},
			x:    []float64{1, -101, 2, -201, 3, -301, 4, -401, 5, -501, -601, -701},
			incX: 2,
			ans:  []float64{-15.854166666666664, -101, -16.395833333333336, -201, 13.208333333333334, -301, 1.583333333333333, -401, -5.000000000000000, -501, -601, -701},
		},
		{
			ul: blas.Lower,
			tA: blas.NoTrans,
			d:  blas.NonUnit,
			n:  5,
			k:  2,
			a: [][]float64{
				{1, 0, 0, 0, 0},
				{3, 6, 0, 0, 0},
				{5, 7, 2, 0, 0},
				{0, 5, 1, 12, 0},
				{0, 0, 5, 3, -1},
			},
			x:    []float64{1, 2, 3, 4, 5},
			incX: 1,
			ans:  []float64{1, -0.166666666666667, -0.416666666666667, 0.437500000000000, -5.770833333333334},
		},
		{
			ul: blas.Lower,
			tA: blas.NoTrans,
			d:  blas.NonUnit,
			n:  5,
			k:  2,
			a: [][]float64{
				{1, 0, 0, 0, 0},
				{3, 6, 0, 0, 0},
				{5, 7, 2, 0, 0},
				{0, 5, 1, 12, 0},
				{0, 0, 5, 3, -1},
			},
			x:    []float64{1, -101, 2, -201, 3, -301, 4, -401, 5, -501, -601, -701},
			incX: 2,
			ans:  []float64{1, -101, -0.166666666666667, -201, -0.416666666666667, -301, 0.437500000000000, -401, -5.770833333333334, -501, -601, -701},
		},
		{
			ul: blas.Upper,
			tA: blas.Trans,
			d:  blas.NonUnit,
			n:  5,
			k:  2,
			a: [][]float64{
				{1, 3, 5, 0, 0},
				{0, 6, 7, 5, 0},
				{0, 0, 2, 1, 5},
				{0, 0, 0, 12, 3},
				{0, 0, 0, 0, -1},
			},
			x:    []float64{1, 2, 3, 4, 5},
			incX: 1,
			ans:  []float64{1, -0.166666666666667, -0.416666666666667, 0.437500000000000, -5.770833333333334},
		},
		{
			ul: blas.Upper,
			tA: blas.Trans,
			d:  blas.NonUnit,
			n:  5,
			k:  2,
			a: [][]float64{
				{1, 3, 5, 0, 0},
				{0, 6, 7, 5, 0},
				{0, 0, 2, 1, 5},
				{0, 0, 0, 12, 3},
				{0, 0, 0, 0, -1},
			},
			x:    []float64{1, -101, 2, -201, 3, -301, 4, -401, 5, -501, -601, -701},
			incX: 2,
			ans:  []float64{1, -101, -0.166666666666667, -201, -0.416666666666667, -301, 0.437500000000000, -401, -5.770833333333334, -501, -601, -701},
		},
		{
			ul: blas.Lower,
			tA: blas.Trans,
			d:  blas.NonUnit,
			n:  5,
			k:  2,
			a: [][]float64{
				{1, 0, 0, 0, 0},
				{3, 6, 0, 0, 0},
				{5, 7, 2, 0, 0},
				{0, 5, 1, 12, 0},
				{0, 0, 5, 3, -1},
			},
			x:    []float64{1, 2, 3, 4, 5},
			incX: 1,
			ans:  []float64{-15.854166666666664, -16.395833333333336, 13.208333333333334, 1.583333333333333, -5.000000000000000},
		},
		{
			ul: blas.Lower,
			tA: blas.Trans,
			d:  blas.NonUnit,
			n:  5,
			k:  2,
			a: [][]float64{
				{1, 0, 0, 0, 0},
				{3, 6, 0, 0, 0},
				{5, 7, 2, 0, 0},
				{0, 5, 1, 12, 0},
				{0, 0, 5, 3, -1},
			},
			x:    []float64{1, -101, 2, -201, 3, -301, 4, -401, 5, -501, -601, -701},
			incX: 2,
			ans:  []float64{-15.854166666666664, -101, -16.395833333333336, -201, 13.208333333333334, -301, 1.583333333333333, -401, -5.000000000000000, -501, -601, -701},
		},
	} {
		var aFlat []float64
		if test.ul == blas.Upper {
			aFlat = flattenBanded(test.a, test.k, 0)
		} else {
			aFlat = flattenBanded(test.a, 0, test.k)
		}
		// Solve on a copy so the table entry keeps the original right-hand side.
		xCopy := sliceCopy(test.x)
		// TODO: Have tests where the banded matrix is constructed explicitly
		// to allow testing for lda != k+1
		blasser.Dtbsv(test.ul, test.tA, test.d, test.n, test.k, aFlat, test.k+1, xCopy, test.incX)
		if !dSliceTolEqual(test.ans, xCopy) {
			t.Errorf("Case %v: Want %v, got %v", i, test.ans, xCopy)
		}
	}

	/*
		// TODO: Uncomment when Dtrsv is fixed (this block also needs the
		// "math/rand" import).
		// Compare with dense for larger matrices
		for _, ul := range [...]blas.Uplo{blas.Upper, blas.Lower} {
			for _, tA := range [...]blas.Transpose{blas.NoTrans, blas.Trans} {
				for _, n := range [...]int{7, 8, 11} {
					for _, d := range [...]blas.Diag{blas.NonUnit, blas.Unit} {
						for _, k := range [...]int{0, 1, 3} {
							for _, incX := range [...]int{1, 3} {
								a := make([][]float64, n)
								for i := range a {
									a[i] = make([]float64, n)
									for j := range a[i] {
										a[i][j] = rand.Float64()
									}
								}
								x := make([]float64, n)
								for i := range x {
									x[i] = rand.Float64()
								}
								extra := 3
								xinc := makeIncremented(x, incX, extra)
								bandX := sliceCopy(xinc)
								var aFlatBand []float64
								if ul == blas.Upper {
									aFlatBand = flattenBanded(a, k, 0)
								} else {
									aFlatBand = flattenBanded(a, 0, k)
								}
								blasser.Dtbsv(ul, tA, d, n, k, aFlatBand, k+1, bandX, incX)
								aFlatDense := flatten(a)
								denseX := sliceCopy(xinc)
								blasser.Dtrsv(ul, tA, d, n, aFlatDense, n, denseX, incX)
								if !dSliceTolEqual(denseX, bandX) {
									t.Errorf("ul=%v, tA=%v, d=%v, n=%v, k=%v, incX=%v: dense banded mismatch", ul, tA, d, n, k, incX)
								}
							}
						}
					}
				}
			}
		}
	*/
}

129
blas/testblas/dtpmv.go Normal file
View File

@@ -0,0 +1,129 @@
package testblas
import (
"testing"
"github.com/gonum/blas"
"github.com/gonum/floats"
)
// Dtpmver is implemented by types that provide the Dtpmv routine exercised
// by DtpmvTest. The test passes ap as the output of flattenTriangular.
type Dtpmver interface {
Dtpmv(ul blas.Uplo, tA blas.Transpose, d blas.Diag, n int, ap []float64, x []float64, incX int)
}
// DtpmvTest checks blasser.Dtpmv against hand-computed triangular
// matrix-vector products x = A*x (NoTrans) and x = Aᵀ*x (Trans), with A
// supplied through flattenTriangular. Each case is run for several
// increments of x.
func DtpmvTest(t *testing.T, blasser Dtpmver) {
	cases := []struct {
		n   int
		a   [][]float64
		x   []float64
		d   blas.Diag
		ul  blas.Uplo
		tA  blas.Transpose
		ans []float64
	}{
		{
			n: 3,
			a: [][]float64{
				{5, 6, 7},
				{0, 9, 10},
				{0, 0, 13},
			},
			x:   []float64{3, 4, 5},
			d:   blas.NonUnit,
			ul:  blas.Upper,
			tA:  blas.NoTrans,
			ans: []float64{74, 86, 65},
		},
		{
			n: 3,
			a: [][]float64{
				{5, 6, 7},
				{0, 9, 10},
				{0, 0, 13},
			},
			x:   []float64{3, 4, 5},
			d:   blas.Unit,
			ul:  blas.Upper,
			tA:  blas.NoTrans,
			ans: []float64{62, 54, 5},
		},
		{
			n: 3,
			a: [][]float64{
				{5, 0, 0},
				{6, 9, 0},
				{7, 10, 13},
			},
			x:   []float64{3, 4, 5},
			d:   blas.NonUnit,
			ul:  blas.Lower,
			tA:  blas.NoTrans,
			ans: []float64{15, 54, 126},
		},
		{
			n: 3,
			a: [][]float64{
				{1, 0, 0},
				{6, 1, 0},
				{7, 10, 1},
			},
			x:   []float64{3, 4, 5},
			d:   blas.Unit,
			ul:  blas.Lower,
			tA:  blas.NoTrans,
			ans: []float64{3, 22, 66},
		},
		{
			n: 3,
			a: [][]float64{
				{5, 6, 7},
				{0, 9, 10},
				{0, 0, 13},
			},
			x:   []float64{3, 4, 5},
			d:   blas.NonUnit,
			ul:  blas.Upper,
			tA:  blas.Trans,
			ans: []float64{15, 54, 126},
		},
		{
			n: 3,
			a: [][]float64{
				{1, 6, 7},
				{0, 1, 10},
				{0, 0, 1},
			},
			x:   []float64{3, 4, 5},
			d:   blas.Unit,
			ul:  blas.Upper,
			tA:  blas.Trans,
			ans: []float64{3, 22, 66},
		},
		{
			n: 3,
			a: [][]float64{
				{5, 0, 0},
				{6, 9, 0},
				{7, 10, 13},
			},
			x:   []float64{3, 4, 5},
			d:   blas.NonUnit,
			ul:  blas.Lower,
			tA:  blas.Trans,
			ans: []float64{74, 86, 65},
		},
	}

	// Each entry is {incX, extra}.
	runs := [][2]int{{1, 0}, {-3, 3}, {4, 3}}

	for i, tc := range cases {
		check := func(incX, extra int) {
			packed := flattenTriangular(tc.a, tc.ul)
			got := makeIncremented(tc.x, incX, extra)
			blasser.Dtpmv(tc.ul, tc.tA, tc.d, tc.n, packed, got, incX)
			want := makeIncremented(tc.ans, incX, extra)
			if floats.EqualApprox(got, want, 1e-14) {
				return
			}
			t.Errorf("Case %v, idx %v: Want %v, got %v.", i, incX, want, got)
		}
		for _, r := range runs {
			check(r[0], r[1])
		}
	}
}

144
blas/testblas/dtpsv.go Normal file
View File

@@ -0,0 +1,144 @@
package testblas
import (
"testing"
"github.com/gonum/blas"
"github.com/gonum/floats"
)
// Dtpsver is implemented by types that provide the Dtpsv routine exercised
// by DtpsvTest. The test passes ap as the output of flattenTriangular.
type Dtpsver interface {
Dtpsv(ul blas.Uplo, tA blas.Transpose, d blas.Diag, n int, ap []float64, x []float64, incX int)
}
// DtpsvTest checks blasser.Dtpsv against precomputed solutions of the
// triangular systems A*x = b (NoTrans) and Aᵀ*x = b (Trans), with A supplied
// through flattenTriangular. The right-hand side is passed in x and
// overwritten with the solution; each case is run for several increments.
func DtpsvTest(t *testing.T, blasser Dtpsver) {
	cases := []struct {
		n   int
		a   [][]float64
		ul  blas.Uplo
		tA  blas.Transpose
		d   blas.Diag
		x   []float64
		ans []float64
	}{
		{
			n: 3,
			a: [][]float64{
				{1, 2, 3},
				{0, 8, 15},
				{0, 0, 8},
			},
			ul:  blas.Upper,
			tA:  blas.NoTrans,
			d:   blas.NonUnit,
			x:   []float64{5, 6, 7},
			ans: []float64{4.15625, -0.890625, 0.875},
		},
		{
			n: 3,
			a: [][]float64{
				{1, 2, 3},
				{0, 1, 15},
				{0, 0, 1},
			},
			ul:  blas.Upper,
			tA:  blas.NoTrans,
			d:   blas.Unit,
			x:   []float64{5, 6, 7},
			ans: []float64{182, -99, 7},
		},
		{
			n: 3,
			a: [][]float64{
				{1, 0, 0},
				{2, 8, 0},
				{3, 15, 8},
			},
			ul:  blas.Lower,
			tA:  blas.NoTrans,
			d:   blas.NonUnit,
			x:   []float64{5, 6, 7},
			ans: []float64{5, -0.5, -0.0625},
		},
		{
			n: 3,
			a: [][]float64{
				{1, 0, 0},
				{2, 8, 0},
				{3, 15, 8},
			},
			ul:  blas.Lower,
			tA:  blas.NoTrans,
			d:   blas.Unit,
			x:   []float64{5, 6, 7},
			ans: []float64{5, -4, 52},
		},
		{
			n: 3,
			a: [][]float64{
				{1, 2, 3},
				{0, 8, 15},
				{0, 0, 8},
			},
			ul:  blas.Upper,
			tA:  blas.Trans,
			d:   blas.NonUnit,
			x:   []float64{5, 6, 7},
			ans: []float64{5, -0.5, -0.0625},
		},
		{
			n: 3,
			a: [][]float64{
				{1, 2, 3},
				{0, 8, 15},
				{0, 0, 8},
			},
			ul:  blas.Upper,
			tA:  blas.Trans,
			d:   blas.Unit,
			x:   []float64{5, 6, 7},
			ans: []float64{5, -4, 52},
		},
		{
			n: 3,
			a: [][]float64{
				{1, 0, 0},
				{2, 8, 0},
				{3, 15, 8},
			},
			ul:  blas.Lower,
			tA:  blas.Trans,
			d:   blas.NonUnit,
			x:   []float64{5, 6, 7},
			ans: []float64{4.15625, -0.890625, 0.875},
		},
		{
			n: 3,
			a: [][]float64{
				{1, 0, 0},
				{2, 1, 0},
				{3, 15, 1},
			},
			ul:  blas.Lower,
			tA:  blas.Trans,
			d:   blas.Unit,
			x:   []float64{5, 6, 7},
			ans: []float64{182, -99, 7},
		},
	}

	// Each entry is {incX, extra}.
	runs := [][2]int{{1, 0}, {-2, 0}, {3, 0}, {-3, 8}, {4, 2}}

	for i, tc := range cases {
		check := func(incX, extra int) {
			packed := flattenTriangular(tc.a, tc.ul)
			got := makeIncremented(tc.x, incX, extra)
			blasser.Dtpsv(tc.ul, tc.tA, tc.d, tc.n, packed, got, incX)
			want := makeIncremented(tc.ans, incX, extra)
			if floats.EqualApprox(got, want, 1e-14) {
				return
			}
			t.Errorf("Case %v, incX = %v: Want %v, got %v.", i, incX, want, got)
		}
		for _, r := range runs {
			check(r[0], r[1])
		}
	}
}

806
blas/testblas/dtrmm.go Normal file
View File

@@ -0,0 +1,806 @@
package testblas
import (
"testing"
"github.com/gonum/blas"
"github.com/gonum/floats"
)
// Dtrmmer is implemented by types that provide the Dtrmm routine exercised
// by DtrmmTest.
type Dtrmmer interface {
Dtrmm(s blas.Side, ul blas.Uplo, tA blas.Transpose, d blas.Diag, m, n int, alpha float64, a []float64, lda int, b []float64, ldb int)
}
// DtrmmTest checks blasser.Dtrmm against hand-computed triangular matrix
// products B = alpha*op(A)*B (side Left) and B = alpha*B*op(A) (side Right),
// covering every combination of side, triangle, transpose and diagonal kind
// for a 4×3 and a 2×3 matrix B.
//
// The cases draw on a small set of shared input matrices. Sharing is safe as
// long as flatten returns a fresh slice and leaves its argument untouched,
// which the per-case flatten calls below rely on anyway.
func DtrmmTest(t *testing.T, blasser Dtrmmer) {
	// Triangular factors used with side Left (order m).
	upper4 := [][]float64{
		{1, 2, 3, 4},
		{0, 5, 6, 7},
		{0, 0, 8, 9},
		{0, 0, 0, 10},
	}
	upper2 := [][]float64{
		{1, 2},
		{0, 5},
	}
	lower4 := [][]float64{
		{1, 0, 0, 0},
		{2, 5, 0, 0},
		{3, 6, 8, 0},
		{4, 7, 9, 10},
	}
	lower2 := [][]float64{
		{1, 0},
		{2, 5},
	}
	// Triangular factors used with side Right (order n).
	upper3 := [][]float64{
		{1, 2, 3},
		{0, 4, 5},
		{0, 0, 6},
	}
	lower3 := [][]float64{
		{1, 0, 0},
		{2, 4, 0},
		{3, 5, 6},
	}
	// Right-hand sides.
	b4x3 := [][]float64{
		{10, 11, 12},
		{13, 14, 15},
		{16, 17, 18},
		{19, 20, 21},
	}
	b2x3 := [][]float64{
		{10, 11, 12},
		{13, 14, 15},
	}

	cases := []struct {
		s     blas.Side
		ul    blas.Uplo
		tA    blas.Transpose
		d     blas.Diag
		m     int
		n     int
		alpha float64
		a     [][]float64
		b     [][]float64
		ans   [][]float64
	}{
		// Side Left, NoTrans.
		{
			s: blas.Left, ul: blas.Upper, tA: blas.NoTrans, d: blas.NonUnit,
			m: 4, n: 3, alpha: 2, a: upper4, b: b4x3,
			ans: [][]float64{
				{320, 340, 360},
				{588, 624, 660},
				{598, 632, 666},
				{380, 400, 420},
			},
		},
		{
			s: blas.Left, ul: blas.Upper, tA: blas.NoTrans, d: blas.NonUnit,
			m: 2, n: 3, alpha: 2, a: upper2, b: b2x3,
			ans: [][]float64{
				{72, 78, 84},
				{130, 140, 150},
			},
		},
		{
			s: blas.Left, ul: blas.Upper, tA: blas.NoTrans, d: blas.Unit,
			m: 4, n: 3, alpha: 2, a: upper4, b: b4x3,
			ans: [][]float64{
				{320, 340, 360},
				{484, 512, 540},
				{374, 394, 414},
				{38, 40, 42},
			},
		},
		{
			s: blas.Left, ul: blas.Upper, tA: blas.NoTrans, d: blas.Unit,
			m: 2, n: 3, alpha: 2, a: upper2, b: b2x3,
			ans: [][]float64{
				{72, 78, 84},
				{26, 28, 30},
			},
		},
		{
			s: blas.Left, ul: blas.Lower, tA: blas.NoTrans, d: blas.NonUnit,
			m: 4, n: 3, alpha: 2, a: lower4, b: b4x3,
			ans: [][]float64{
				{20, 22, 24},
				{170, 184, 198},
				{472, 506, 540},
				{930, 990, 1050},
			},
		},
		{
			s: blas.Left, ul: blas.Lower, tA: blas.NoTrans, d: blas.NonUnit,
			m: 2, n: 3, alpha: 2, a: lower2, b: b2x3,
			ans: [][]float64{
				{20, 22, 24},
				{170, 184, 198},
			},
		},
		{
			s: blas.Left, ul: blas.Lower, tA: blas.NoTrans, d: blas.Unit,
			m: 4, n: 3, alpha: 2, a: lower4, b: b4x3,
			ans: [][]float64{
				{20, 22, 24},
				{66, 72, 78},
				{248, 268, 288},
				{588, 630, 672},
			},
		},
		{
			s: blas.Left, ul: blas.Lower, tA: blas.NoTrans, d: blas.Unit,
			m: 2, n: 3, alpha: 2, a: lower2, b: b2x3,
			ans: [][]float64{
				{20, 22, 24},
				{66, 72, 78},
			},
		},
		// Side Left, Trans.
		{
			s: blas.Left, ul: blas.Upper, tA: blas.Trans, d: blas.NonUnit,
			m: 4, n: 3, alpha: 2, a: upper4, b: b4x3,
			ans: [][]float64{
				{20, 22, 24},
				{170, 184, 198},
				{472, 506, 540},
				{930, 990, 1050},
			},
		},
		{
			s: blas.Left, ul: blas.Upper, tA: blas.Trans, d: blas.NonUnit,
			m: 2, n: 3, alpha: 2, a: upper2, b: b2x3,
			ans: [][]float64{
				{20, 22, 24},
				{170, 184, 198},
			},
		},
		{
			s: blas.Left, ul: blas.Upper, tA: blas.Trans, d: blas.Unit,
			m: 4, n: 3, alpha: 2, a: upper4, b: b4x3,
			ans: [][]float64{
				{20, 22, 24},
				{66, 72, 78},
				{248, 268, 288},
				{588, 630, 672},
			},
		},
		{
			s: blas.Left, ul: blas.Upper, tA: blas.Trans, d: blas.Unit,
			m: 2, n: 3, alpha: 2, a: upper2, b: b2x3,
			ans: [][]float64{
				{20, 22, 24},
				{66, 72, 78},
			},
		},
		{
			s: blas.Left, ul: blas.Lower, tA: blas.Trans, d: blas.NonUnit,
			m: 4, n: 3, alpha: 2, a: lower4, b: b4x3,
			ans: [][]float64{
				{320, 340, 360},
				{588, 624, 660},
				{598, 632, 666},
				{380, 400, 420},
			},
		},
		{
			s: blas.Left, ul: blas.Lower, tA: blas.Trans, d: blas.NonUnit,
			m: 2, n: 3, alpha: 2, a: lower2, b: b2x3,
			ans: [][]float64{
				{72, 78, 84},
				{130, 140, 150},
			},
		},
		{
			s: blas.Left, ul: blas.Lower, tA: blas.Trans, d: blas.Unit,
			m: 4, n: 3, alpha: 2, a: lower4, b: b4x3,
			ans: [][]float64{
				{320, 340, 360},
				{484, 512, 540},
				{374, 394, 414},
				{38, 40, 42},
			},
		},
		{
			s: blas.Left, ul: blas.Lower, tA: blas.Trans, d: blas.Unit,
			m: 2, n: 3, alpha: 2, a: lower2, b: b2x3,
			ans: [][]float64{
				{72, 78, 84},
				{26, 28, 30},
			},
		},
		// Side Right, NoTrans.
		{
			s: blas.Right, ul: blas.Upper, tA: blas.NoTrans, d: blas.NonUnit,
			m: 4, n: 3, alpha: 2, a: upper3, b: b4x3,
			ans: [][]float64{
				{20, 128, 314},
				{26, 164, 398},
				{32, 200, 482},
				{38, 236, 566},
			},
		},
		{
			s: blas.Right, ul: blas.Upper, tA: blas.NoTrans, d: blas.NonUnit,
			m: 2, n: 3, alpha: 2, a: upper3, b: b2x3,
			ans: [][]float64{
				{20, 128, 314},
				{26, 164, 398},
			},
		},
		{
			s: blas.Right, ul: blas.Upper, tA: blas.NoTrans, d: blas.Unit,
			m: 4, n: 3, alpha: 2, a: upper3, b: b4x3,
			ans: [][]float64{
				{20, 62, 194},
				{26, 80, 248},
				{32, 98, 302},
				{38, 116, 356},
			},
		},
		{
			s: blas.Right, ul: blas.Upper, tA: blas.NoTrans, d: blas.Unit,
			m: 2, n: 3, alpha: 2, a: upper3, b: b2x3,
			ans: [][]float64{
				{20, 62, 194},
				{26, 80, 248},
			},
		},
		{
			s: blas.Right, ul: blas.Lower, tA: blas.NoTrans, d: blas.NonUnit,
			m: 4, n: 3, alpha: 2, a: lower3, b: b4x3,
			ans: [][]float64{
				{136, 208, 144},
				{172, 262, 180},
				{208, 316, 216},
				{244, 370, 252},
			},
		},
		{
			s: blas.Right, ul: blas.Lower, tA: blas.NoTrans, d: blas.NonUnit,
			m: 2, n: 3, alpha: 2, a: lower3, b: b2x3,
			ans: [][]float64{
				{136, 208, 144},
				{172, 262, 180},
			},
		},
		{
			s: blas.Right, ul: blas.Lower, tA: blas.NoTrans, d: blas.Unit,
			m: 4, n: 3, alpha: 2, a: lower3, b: b4x3,
			ans: [][]float64{
				{136, 142, 24},
				{172, 178, 30},
				{208, 214, 36},
				{244, 250, 42},
			},
		},
		{
			s: blas.Right, ul: blas.Lower, tA: blas.NoTrans, d: blas.Unit,
			m: 2, n: 3, alpha: 2, a: lower3, b: b2x3,
			ans: [][]float64{
				{136, 142, 24},
				{172, 178, 30},
			},
		},
		// Side Right, Trans.
		{
			s: blas.Right, ul: blas.Upper, tA: blas.Trans, d: blas.NonUnit,
			m: 4, n: 3, alpha: 2, a: upper3, b: b4x3,
			ans: [][]float64{
				{136, 208, 144},
				{172, 262, 180},
				{208, 316, 216},
				{244, 370, 252},
			},
		},
		{
			s: blas.Right, ul: blas.Upper, tA: blas.Trans, d: blas.NonUnit,
			m: 2, n: 3, alpha: 2, a: upper3, b: b2x3,
			ans: [][]float64{
				{136, 208, 144},
				{172, 262, 180},
			},
		},
		{
			s: blas.Right, ul: blas.Upper, tA: blas.Trans, d: blas.Unit,
			m: 4, n: 3, alpha: 2, a: upper3, b: b4x3,
			ans: [][]float64{
				{136, 142, 24},
				{172, 178, 30},
				{208, 214, 36},
				{244, 250, 42},
			},
		},
		{
			s: blas.Right, ul: blas.Upper, tA: blas.Trans, d: blas.Unit,
			m: 2, n: 3, alpha: 2, a: upper3, b: b2x3,
			ans: [][]float64{
				{136, 142, 24},
				{172, 178, 30},
			},
		},
		{
			s: blas.Right, ul: blas.Lower, tA: blas.Trans, d: blas.NonUnit,
			m: 4, n: 3, alpha: 2, a: lower3, b: b4x3,
			ans: [][]float64{
				{20, 128, 314},
				{26, 164, 398},
				{32, 200, 482},
				{38, 236, 566},
			},
		},
		{
			s: blas.Right, ul: blas.Lower, tA: blas.Trans, d: blas.NonUnit,
			m: 2, n: 3, alpha: 2, a: lower3, b: b2x3,
			ans: [][]float64{
				{20, 128, 314},
				{26, 164, 398},
			},
		},
		{
			s: blas.Right, ul: blas.Lower, tA: blas.Trans, d: blas.Unit,
			m: 4, n: 3, alpha: 2, a: lower3, b: b4x3,
			ans: [][]float64{
				{20, 62, 194},
				{26, 80, 248},
				{32, 98, 302},
				{38, 116, 356},
			},
		},
		{
			s: blas.Right, ul: blas.Lower, tA: blas.Trans, d: blas.Unit,
			m: 2, n: 3, alpha: 2, a: lower3, b: b2x3,
			ans: [][]float64{
				{20, 62, 194},
				{26, 80, 248},
			},
		},
	}

	for i, tc := range cases {
		a := flatten(tc.a)
		got := flatten(tc.b)
		want := flatten(tc.ans)
		// Leading dimensions are the row lengths of the literal matrices.
		lda, ldb := len(tc.a[0]), len(tc.b[0])
		blasser.Dtrmm(tc.s, tc.ul, tc.tA, tc.d, tc.m, tc.n, tc.alpha, a, lda, got, ldb)
		if floats.EqualApprox(want, got, 1e-14) {
			continue
		}
		t.Errorf("Case %v. Want %v, got %v.", i, want, got)
	}
}

147
blas/testblas/dtrmv.go Normal file
View File

@@ -0,0 +1,147 @@
package testblas
import (
"testing"
"github.com/gonum/blas"
"github.com/gonum/floats"
)
// Dtrmver is the interface a BLAS implementation must satisfy to be exercised
// by DtrmvTest and DtrmvBenchmark.
//
// The tests in this package expect Dtrmv to overwrite x in place with the
// product of the triangular matrix held in a (optionally transposed, as
// selected by tA) and x; ul selects which triangle is used and d whether the
// diagonal is taken as all ones.
type Dtrmver interface {
Dtrmv(ul blas.Uplo, tA blas.Transpose, d blas.Diag, n int, a []float64, lda int, x []float64, incX int)
}
// DtrmvTest checks a Dtrmv implementation against hand-computed
// triangular-matrix/vector products.
//
// Each case supplies the matrix a as dense rows (flattened and passed with
// lda == n), the input vector x and the expected contents of x after the
// call, ans. Every case is run with several increments, including a negative
// one, and the same incX/extra pair is used to lay out both the input and the
// expected vector via makeIncremented.
func DtrmvTest(t *testing.T, blasser Dtrmver) {
	for i, test := range []struct {
		n   int
		a   [][]float64
		x   []float64
		d   blas.Diag
		ul  blas.Uplo
		tA  blas.Transpose
		ans []float64
	}{
		{
			n:   1,
			a:   [][]float64{{5}},
			x:   []float64{2},
			d:   blas.NonUnit,
			ul:  blas.Upper,
			tA:  blas.NoTrans,
			ans: []float64{10},
		},
		{
			n:   1,
			a:   [][]float64{{5}},
			x:   []float64{2},
			d:   blas.Unit,
			ul:  blas.Upper,
			tA:  blas.NoTrans,
			ans: []float64{2},
		},
		{
			n: 3,
			a: [][]float64{
				{5, 6, 7},
				{0, 9, 10},
				{0, 0, 13},
			},
			x:   []float64{3, 4, 5},
			d:   blas.NonUnit,
			ul:  blas.Upper,
			tA:  blas.NoTrans,
			ans: []float64{74, 86, 65},
		},
		{
			n: 3,
			a: [][]float64{
				{5, 6, 7},
				{0, 9, 10},
				{0, 0, 13},
			},
			x:   []float64{3, 4, 5},
			d:   blas.Unit,
			ul:  blas.Upper,
			tA:  blas.NoTrans,
			ans: []float64{62, 54, 5},
		},
		{
			n: 3,
			a: [][]float64{
				{5, 0, 0},
				{6, 9, 0},
				{7, 10, 13},
			},
			x:   []float64{3, 4, 5},
			d:   blas.NonUnit,
			ul:  blas.Lower,
			tA:  blas.NoTrans,
			ans: []float64{15, 54, 126},
		},
		{
			n: 3,
			a: [][]float64{
				{1, 0, 0},
				{6, 1, 0},
				{7, 10, 1},
			},
			x:   []float64{3, 4, 5},
			d:   blas.Unit,
			ul:  blas.Lower,
			tA:  blas.NoTrans,
			ans: []float64{3, 22, 66},
		},
		{
			n: 3,
			a: [][]float64{
				{5, 6, 7},
				{0, 9, 10},
				{0, 0, 13},
			},
			x:   []float64{3, 4, 5},
			d:   blas.NonUnit,
			ul:  blas.Upper,
			tA:  blas.Trans,
			ans: []float64{15, 54, 126},
		},
		{
			n: 3,
			a: [][]float64{
				{1, 6, 7},
				{0, 1, 10},
				{0, 0, 1},
			},
			x:   []float64{3, 4, 5},
			d:   blas.Unit,
			ul:  blas.Upper,
			tA:  blas.Trans,
			ans: []float64{3, 22, 66},
		},
		{
			n: 3,
			a: [][]float64{
				{5, 0, 0},
				{6, 9, 0},
				{7, 10, 13},
			},
			x:   []float64{3, 4, 5},
			d:   blas.NonUnit,
			ul:  blas.Lower,
			tA:  blas.Trans,
			ans: []float64{74, 86, 65},
		},
	} {
		// incTest runs the current case with the given increment; extra is
		// forwarded unchanged to makeIncremented for both vectors.
		incTest := func(incX, extra int) {
			aFlat := flatten(test.a)
			x := makeIncremented(test.x, incX, extra)
			blasser.Dtrmv(test.ul, test.tA, test.d, test.n, aFlat, test.n, x, incX)
			ans := makeIncremented(test.ans, incX, extra)
			if !floats.EqualApprox(x, ans, 1e-14) {
				// The second value reported is the increment, not an index,
				// so label it as such (matching DtrsvTest).
				t.Errorf("Case %v, incX = %v: Want %v, got %v.", i, incX, ans, x)
			}
		}
		incTest(1, 3)
		incTest(-3, 3)
		incTest(4, 3)
	}
}

View File

@@ -0,0 +1,30 @@
// Copyright ©2015 The gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package testblas
import (
"math/rand"
"testing"
"github.com/gonum/blas"
)
// DtrmvBenchmark times repeated calls to dtrmv.Dtrmv with the given shape and
// flags.
//
// The matrix is allocated as n*lda elements and the vector as n*|incX|
// elements; both are filled from a fixed-seed source so that runs are
// comparable. Setup is excluded from the timed region. Sizing the vector by
// the magnitude of incX allows negative increments to be benchmarked, which
// would otherwise make the allocation panic on a negative length.
func DtrmvBenchmark(b *testing.B, dtrmv Dtrmver, n, lda, incX int, ul blas.Uplo, tA blas.Transpose, d blas.Diag) {
	rnd := rand.New(rand.NewSource(0))
	a := make([]float64, n*lda)
	for i := range a {
		a[i] = rnd.Float64()
	}

	// make panics when given a negative length, so use |incX| for the size.
	stride := incX
	if stride < 0 {
		stride = -stride
	}
	x := make([]float64, n*stride)
	for i := range x {
		x[i] = rnd.Float64()
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		dtrmv.Dtrmv(ul, tA, d, n, a, lda, x, incX)
	}
}

811
blas/testblas/dtrsm.go Normal file
View File

@@ -0,0 +1,811 @@
package testblas
import (
"testing"
"github.com/gonum/blas"
"github.com/gonum/floats"
)
// Dtrsmer is the interface a BLAS implementation must satisfy to be exercised
// by DtrsmTest.
//
// The tests in this package expect Dtrsm to overwrite b in place with the
// solution X of a triangular system whose right-hand side is alpha times the
// original b: the triangular matrix in a (optionally transposed, as selected
// by tA) multiplies X from the left when s is blas.Left and from the right
// when s is blas.Right.
type Dtrsmer interface {
Dtrsm(s blas.Side, ul blas.Uplo, tA blas.Transpose, d blas.Diag, m, n int,
alpha float64, a []float64, lda int, b []float64, ldb int)
}
// DtrsmTest checks a Dtrsm implementation against hand-computed solutions of
// small triangular systems.
//
// The table covers both sides, both triangles, transposed and untransposed
// matrices, and unit and non-unit diagonals, each with two right-hand-side
// shapes. For every case the matrices are flattened row by row, Dtrsm is
// called to overwrite b, and the result is compared with ans to within 1e-13.
func DtrsmTest(t *testing.T, blasser Dtrsmer) {
	for i, test := range []struct {
		s     blas.Side
		ul    blas.Uplo
		tA    blas.Transpose
		d     blas.Diag
		m     int
		n     int
		alpha float64
		a     [][]float64
		b     [][]float64
		ans   [][]float64
	}{
		{
			s:     blas.Left,
			ul:    blas.Upper,
			tA:    blas.NoTrans,
			d:     blas.NonUnit,
			m:     3,
			n:     2,
			alpha: 2,
			a: [][]float64{
				{1, 2, 3},
				{0, 4, 5},
				{0, 0, 5},
			},
			b: [][]float64{
				{3, 6},
				{4, 7},
				{5, 8},
			},
			ans: [][]float64{
				{1, 3.4},
				{-0.5, -0.5},
				{2, 3.2},
			},
		},
		{
			s:     blas.Left,
			ul:    blas.Upper,
			tA:    blas.NoTrans,
			d:     blas.Unit,
			m:     3,
			n:     2,
			alpha: 2,
			a: [][]float64{
				{1, 2, 3},
				{0, 4, 5},
				{0, 0, 5},
			},
			b: [][]float64{
				{3, 6},
				{4, 7},
				{5, 8},
			},
			ans: [][]float64{
				{60, 96},
				{-42, -66},
				{10, 16},
			},
		},
		{
			s:     blas.Left,
			ul:    blas.Upper,
			tA:    blas.NoTrans,
			d:     blas.NonUnit,
			m:     3,
			n:     4,
			alpha: 2,
			a: [][]float64{
				{1, 2, 3},
				{0, 4, 5},
				{0, 0, 5},
			},
			b: [][]float64{
				{3, 6, 2, 9},
				{4, 7, 1, 3},
				{5, 8, 9, 10},
			},
			ans: [][]float64{
				{1, 3.4, 1.2, 13},
				{-0.5, -0.5, -4, -3.5},
				{2, 3.2, 3.6, 4},
			},
		},
		{
			s:     blas.Left,
			ul:    blas.Upper,
			tA:    blas.NoTrans,
			d:     blas.Unit,
			m:     3,
			n:     4,
			alpha: 2,
			a: [][]float64{
				{1, 2, 3},
				{0, 4, 5},
				{0, 0, 5},
			},
			b: [][]float64{
				{3, 6, 2, 9},
				{4, 7, 1, 3},
				{5, 8, 9, 10},
			},
			ans: [][]float64{
				{60, 96, 126, 146},
				{-42, -66, -88, -94},
				{10, 16, 18, 20},
			},
		},
		{
			s:     blas.Left,
			ul:    blas.Lower,
			tA:    blas.NoTrans,
			d:     blas.NonUnit,
			m:     3,
			n:     2,
			alpha: 3,
			a: [][]float64{
				{2, 0, 0},
				{3, 4, 0},
				{5, 6, 7},
			},
			b: [][]float64{
				{3, 6},
				{4, 7},
				{5, 8},
			},
			ans: [][]float64{
				{4.5, 9},
				{-0.375, -1.5},
				{-0.75, -12.0 / 7},
			},
		},
		{
			s:     blas.Left,
			ul:    blas.Lower,
			tA:    blas.NoTrans,
			d:     blas.Unit,
			m:     3,
			n:     2,
			alpha: 3,
			a: [][]float64{
				{2, 0, 0},
				{3, 4, 0},
				{5, 6, 7},
			},
			b: [][]float64{
				{3, 6},
				{4, 7},
				{5, 8},
			},
			ans: [][]float64{
				{9, 18},
				{-15, -33},
				{60, 132},
			},
		},
		{
			s:     blas.Left,
			ul:    blas.Lower,
			tA:    blas.NoTrans,
			d:     blas.NonUnit,
			m:     3,
			n:     4,
			alpha: 3,
			a: [][]float64{
				{2, 0, 0},
				{3, 4, 0},
				{5, 6, 7},
			},
			b: [][]float64{
				{3, 6, 2, 9},
				{4, 7, 1, 3},
				{5, 8, 9, 10},
			},
			ans: [][]float64{
				{4.5, 9, 3, 13.5},
				{-0.375, -1.5, -1.5, -63.0 / 8},
				{-0.75, -12.0 / 7, 3, 39.0 / 28},
			},
		},
		{
			s:     blas.Left,
			ul:    blas.Lower,
			tA:    blas.NoTrans,
			d:     blas.Unit,
			m:     3,
			n:     4,
			alpha: 3,
			a: [][]float64{
				{2, 0, 0},
				{3, 4, 0},
				{5, 6, 7},
			},
			b: [][]float64{
				{3, 6, 2, 9},
				{4, 7, 1, 3},
				{5, 8, 9, 10},
			},
			ans: [][]float64{
				{9, 18, 6, 27},
				{-15, -33, -15, -72},
				{60, 132, 87, 327},
			},
		},
		{
			s:     blas.Left,
			ul:    blas.Upper,
			tA:    blas.Trans,
			d:     blas.NonUnit,
			m:     3,
			n:     2,
			alpha: 3,
			a: [][]float64{
				{2, 3, 4},
				{0, 5, 6},
				{0, 0, 7},
			},
			b: [][]float64{
				{3, 6},
				{4, 7},
				{5, 8},
			},
			ans: [][]float64{
				{4.5, 9},
				{-0.30, -1.2},
				{-6.0 / 35, -24.0 / 35},
			},
		},
		{
			s:     blas.Left,
			ul:    blas.Upper,
			tA:    blas.Trans,
			d:     blas.Unit,
			m:     3,
			n:     2,
			alpha: 3,
			a: [][]float64{
				{2, 3, 4},
				{0, 5, 6},
				{0, 0, 7},
			},
			b: [][]float64{
				{3, 6},
				{4, 7},
				{5, 8},
			},
			ans: [][]float64{
				{9, 18},
				{-15, -33},
				{69, 150},
			},
		},
		{
			s:     blas.Left,
			ul:    blas.Upper,
			tA:    blas.Trans,
			d:     blas.NonUnit,
			m:     3,
			n:     4,
			alpha: 3,
			a: [][]float64{
				{2, 3, 4},
				{0, 5, 6},
				{0, 0, 7},
			},
			b: [][]float64{
				{3, 6, 6, 7},
				{4, 7, 8, 9},
				{5, 8, 10, 11},
			},
			ans: [][]float64{
				{4.5, 9, 9, 10.5},
				{-0.3, -1.2, -0.6, -0.9},
				{-6.0 / 35, -24.0 / 35, -12.0 / 35, -18.0 / 35},
			},
		},
		{
			s:     blas.Left,
			ul:    blas.Upper,
			tA:    blas.Trans,
			d:     blas.Unit,
			m:     3,
			n:     4,
			alpha: 3,
			a: [][]float64{
				{2, 3, 4},
				{0, 5, 6},
				{0, 0, 7},
			},
			b: [][]float64{
				{3, 6, 6, 7},
				{4, 7, 8, 9},
				{5, 8, 10, 11},
			},
			ans: [][]float64{
				{9, 18, 18, 21},
				{-15, -33, -30, -36},
				{69, 150, 138, 165},
			},
		},
		{
			s:     blas.Left,
			ul:    blas.Lower,
			tA:    blas.Trans,
			d:     blas.NonUnit,
			m:     3,
			n:     2,
			alpha: 3,
			a: [][]float64{
				{2, 0, 0},
				{3, 4, 0},
				{5, 6, 8},
			},
			b: [][]float64{
				{3, 6},
				{4, 7},
				{5, 8},
			},
			ans: [][]float64{
				{-0.46875, 0.375},
				{0.1875, 0.75},
				{1.875, 3},
			},
		},
		{
			s:     blas.Left,
			ul:    blas.Lower,
			tA:    blas.Trans,
			d:     blas.Unit,
			m:     3,
			n:     2,
			alpha: 3,
			a: [][]float64{
				{2, 0, 0},
				{3, 4, 0},
				{5, 6, 8},
			},
			b: [][]float64{
				{3, 6},
				{4, 7},
				{5, 8},
			},
			ans: [][]float64{
				{168, 267},
				{-78, -123},
				{15, 24},
			},
		},
		{
			s:     blas.Left,
			ul:    blas.Lower,
			tA:    blas.Trans,
			d:     blas.NonUnit,
			m:     3,
			n:     4,
			alpha: 3,
			a: [][]float64{
				{2, 0, 0},
				{3, 4, 0},
				{5, 6, 8},
			},
			b: [][]float64{
				{3, 6, 2, 3},
				{4, 7, 4, 5},
				{5, 8, 6, 7},
			},
			ans: [][]float64{
				{-0.46875, 0.375, -2.0625, -1.78125},
				{0.1875, 0.75, -0.375, -0.1875},
				{1.875, 3, 2.25, 2.625},
			},
		},
		{
			s:     blas.Left,
			ul:    blas.Lower,
			tA:    blas.Trans,
			d:     blas.Unit,
			m:     3,
			n:     4,
			alpha: 3,
			a: [][]float64{
				{2, 0, 0},
				{3, 4, 0},
				{5, 6, 8},
			},
			b: [][]float64{
				{3, 6, 2, 3},
				{4, 7, 4, 5},
				{5, 8, 6, 7},
			},
			ans: [][]float64{
				{168, 267, 204, 237},
				{-78, -123, -96, -111},
				{15, 24, 18, 21},
			},
		},
		{
			s:     blas.Right,
			ul:    blas.Upper,
			tA:    blas.NoTrans,
			d:     blas.NonUnit,
			m:     4,
			n:     3,
			alpha: 3,
			a: [][]float64{
				{2, 3, 4},
				{0, 5, 6},
				{0, 0, 7},
			},
			b: [][]float64{
				{10, 11, 12},
				{13, 14, 15},
				{16, 17, 18},
				{19, 20, 21},
			},
			ans: [][]float64{
				{15, -2.4, -48.0 / 35},
				{19.5, -3.3, -66.0 / 35},
				{24, -4.2, -2.4},
				{28.5, -5.1, -102.0 / 35},
			},
		},
		{
			s:     blas.Right,
			ul:    blas.Upper,
			tA:    blas.NoTrans,
			d:     blas.Unit,
			m:     4,
			n:     3,
			alpha: 3,
			a: [][]float64{
				{2, 3, 4},
				{0, 5, 6},
				{0, 0, 8},
			},
			b: [][]float64{
				{10, 11, 12},
				{13, 14, 15},
				{16, 17, 18},
				{19, 20, 21},
			},
			ans: [][]float64{
				{30, -57, 258},
				{39, -75, 339},
				{48, -93, 420},
				{57, -111, 501},
			},
		},
		{
			s:     blas.Right,
			ul:    blas.Upper,
			tA:    blas.NoTrans,
			d:     blas.NonUnit,
			m:     2,
			n:     3,
			alpha: 3,
			a: [][]float64{
				{2, 3, 4},
				{0, 5, 6},
				{0, 0, 7},
			},
			b: [][]float64{
				{10, 11, 12},
				{13, 14, 15},
			},
			ans: [][]float64{
				{15, -2.4, -48.0 / 35},
				{19.5, -3.3, -66.0 / 35},
			},
		},
		{
			s:     blas.Right,
			ul:    blas.Upper,
			tA:    blas.NoTrans,
			d:     blas.Unit,
			m:     2,
			n:     3,
			alpha: 3,
			a: [][]float64{
				{2, 3, 4},
				{0, 5, 6},
				{0, 0, 8},
			},
			b: [][]float64{
				{10, 11, 12},
				{13, 14, 15},
			},
			ans: [][]float64{
				{30, -57, 258},
				{39, -75, 339},
			},
		},
		{
			s:     blas.Right,
			ul:    blas.Lower,
			tA:    blas.NoTrans,
			d:     blas.NonUnit,
			m:     4,
			n:     3,
			alpha: 3,
			a: [][]float64{
				{2, 0, 0},
				{3, 5, 0},
				{4, 6, 8},
			},
			b: [][]float64{
				{10, 11, 12},
				{13, 14, 15},
				{16, 17, 18},
				{19, 20, 21},
			},
			ans: [][]float64{
				{4.2, 1.2, 4.5},
				{5.775, 1.65, 5.625},
				{7.35, 2.1, 6.75},
				{8.925, 2.55, 7.875},
			},
		},
		{
			s:     blas.Right,
			ul:    blas.Lower,
			tA:    blas.NoTrans,
			d:     blas.Unit,
			m:     4,
			n:     3,
			alpha: 3,
			a: [][]float64{
				{2, 0, 0},
				{3, 5, 0},
				{4, 6, 8},
			},
			b: [][]float64{
				{10, 11, 12},
				{13, 14, 15},
				{16, 17, 18},
				{19, 20, 21},
			},
			ans: [][]float64{
				{435, -183, 36},
				{543, -228, 45},
				{651, -273, 54},
				{759, -318, 63},
			},
		},
		{
			s:     blas.Right,
			ul:    blas.Lower,
			tA:    blas.NoTrans,
			d:     blas.NonUnit,
			m:     2,
			n:     3,
			alpha: 3,
			a: [][]float64{
				{2, 0, 0},
				{3, 5, 0},
				{4, 6, 8},
			},
			b: [][]float64{
				{10, 11, 12},
				{13, 14, 15},
			},
			ans: [][]float64{
				{4.2, 1.2, 4.5},
				{5.775, 1.65, 5.625},
			},
		},
		{
			s:     blas.Right,
			ul:    blas.Lower,
			tA:    blas.NoTrans,
			d:     blas.Unit,
			m:     2,
			n:     3,
			alpha: 3,
			a: [][]float64{
				{2, 0, 0},
				{3, 5, 0},
				{4, 6, 8},
			},
			b: [][]float64{
				{10, 11, 12},
				{13, 14, 15},
			},
			ans: [][]float64{
				{435, -183, 36},
				{543, -228, 45},
			},
		},
		{
			s:     blas.Right,
			ul:    blas.Upper,
			tA:    blas.Trans,
			d:     blas.NonUnit,
			m:     4,
			n:     3,
			alpha: 3,
			a: [][]float64{
				{2, 3, 4},
				{0, 5, 6},
				{0, 0, 8},
			},
			b: [][]float64{
				{10, 11, 12},
				{13, 14, 15},
				{16, 17, 18},
				{19, 20, 21},
			},
			ans: [][]float64{
				{4.2, 1.2, 4.5},
				{5.775, 1.65, 5.625},
				{7.35, 2.1, 6.75},
				{8.925, 2.55, 7.875},
			},
		},
		{
			s:     blas.Right,
			ul:    blas.Upper,
			tA:    blas.Trans,
			d:     blas.Unit,
			m:     4,
			n:     3,
			alpha: 3,
			a: [][]float64{
				{2, 3, 4},
				{0, 5, 6},
				{0, 0, 8},
			},
			b: [][]float64{
				{10, 11, 12},
				{13, 14, 15},
				{16, 17, 18},
				{19, 20, 21},
			},
			ans: [][]float64{
				{435, -183, 36},
				{543, -228, 45},
				{651, -273, 54},
				{759, -318, 63},
			},
		},
		{
			s:     blas.Right,
			ul:    blas.Upper,
			tA:    blas.Trans,
			d:     blas.NonUnit,
			m:     2,
			n:     3,
			alpha: 3,
			a: [][]float64{
				{2, 3, 4},
				{0, 5, 6},
				{0, 0, 8},
			},
			b: [][]float64{
				{10, 11, 12},
				{13, 14, 15},
			},
			ans: [][]float64{
				{4.2, 1.2, 4.5},
				{5.775, 1.65, 5.625},
			},
		},
		{
			s:     blas.Right,
			ul:    blas.Upper,
			tA:    blas.Trans,
			d:     blas.Unit,
			m:     2,
			n:     3,
			alpha: 3,
			a: [][]float64{
				{2, 3, 4},
				{0, 5, 6},
				{0, 0, 8},
			},
			b: [][]float64{
				{10, 11, 12},
				{13, 14, 15},
			},
			ans: [][]float64{
				{435, -183, 36},
				{543, -228, 45},
			},
		},
		{
			s:     blas.Right,
			ul:    blas.Lower,
			tA:    blas.Trans,
			d:     blas.NonUnit,
			m:     4,
			n:     3,
			alpha: 3,
			a: [][]float64{
				{2, 0, 0},
				{3, 5, 0},
				{4, 6, 8},
			},
			b: [][]float64{
				{10, 11, 12},
				{13, 14, 15},
				{16, 17, 18},
				{19, 20, 21},
			},
			ans: [][]float64{
				{15, -2.4, -1.2},
				{19.5, -3.3, -1.65},
				{24, -4.2, -2.1},
				{28.5, -5.1, -2.55},
			},
		},
		{
			s:     blas.Right,
			ul:    blas.Lower,
			tA:    blas.Trans,
			d:     blas.Unit,
			m:     4,
			n:     3,
			alpha: 3,
			a: [][]float64{
				{2, 0, 0},
				{3, 5, 0},
				{4, 6, 8},
			},
			b: [][]float64{
				{10, 11, 12},
				{13, 14, 15},
				{16, 17, 18},
				{19, 20, 21},
			},
			ans: [][]float64{
				{30, -57, 258},
				{39, -75, 339},
				{48, -93, 420},
				{57, -111, 501},
			},
		},
		{
			s:     blas.Right,
			ul:    blas.Lower,
			tA:    blas.Trans,
			d:     blas.NonUnit,
			m:     2,
			n:     3,
			alpha: 3,
			a: [][]float64{
				{2, 0, 0},
				{3, 5, 0},
				{4, 6, 8},
			},
			b: [][]float64{
				{10, 11, 12},
				{13, 14, 15},
			},
			ans: [][]float64{
				{15, -2.4, -1.2},
				{19.5, -3.3, -1.65},
			},
		},
		{
			s:     blas.Right,
			ul:    blas.Lower,
			tA:    blas.Trans,
			d:     blas.Unit,
			m:     2,
			n:     3,
			alpha: 3,
			a: [][]float64{
				{2, 0, 0},
				{3, 5, 0},
				{4, 6, 8},
			},
			b: [][]float64{
				{10, 11, 12},
				{13, 14, 15},
			},
			ans: [][]float64{
				{30, -57, 258},
				{39, -75, 339},
			},
		},
	} {
		matrix := flatten(test.a)
		got := flatten(test.b)
		want := flatten(test.ans)

		// The leading dimension of a follows the side it is applied on:
		// m for a left-side matrix, n otherwise.
		lda := test.n
		if test.s == blas.Left {
			lda = test.m
		}

		blasser.Dtrsm(test.s, test.ul, test.tA, test.d, test.m, test.n, test.alpha, matrix, lda, got, test.n)
		if floats.EqualApprox(want, got, 1e-13) {
			continue
		}
		t.Errorf("Case %v: Want %v, got %v.", i, want, got)
	}
}

144
blas/testblas/dtrsv.go Normal file
View File

@@ -0,0 +1,144 @@
package testblas
import (
"testing"
"github.com/gonum/blas"
"github.com/gonum/floats"
)
// Dtrsver is the interface a BLAS implementation must satisfy to be exercised
// by DtrsvTest.
//
// The tests in this package expect Dtrsv to overwrite x in place with the
// solution of the triangular system whose matrix is held in a (optionally
// transposed, as selected by tA) and whose right-hand side is the original x.
type Dtrsver interface {
Dtrsv(ul blas.Uplo, tA blas.Transpose, d blas.Diag, n int, a []float64, lda int, x []float64, incX int)
}
// DtrsvTest checks a Dtrsv implementation against hand-computed solutions of
// small triangular systems.
//
// Each case gives the matrix a as dense rows (flattened and passed with
// lda == n), the right-hand side x and the expected contents of x after the
// call, ans. Every case is repeated for a fixed list of increment/extra
// pairs, which are handed to makeIncremented to lay out both vectors.
func DtrsvTest(t *testing.T, blasser Dtrsver) {
	for i, test := range []struct {
		n   int
		a   [][]float64
		ul  blas.Uplo
		tA  blas.Transpose
		d   blas.Diag
		x   []float64
		ans []float64
	}{
		{
			n: 3,
			a: [][]float64{
				{1, 2, 3},
				{0, 8, 15},
				{0, 0, 8},
			},
			ul:  blas.Upper,
			tA:  blas.NoTrans,
			d:   blas.NonUnit,
			x:   []float64{5, 6, 7},
			ans: []float64{4.15625, -0.890625, 0.875},
		},
		{
			n: 3,
			a: [][]float64{
				{1, 2, 3},
				{0, 1, 15},
				{0, 0, 1},
			},
			ul:  blas.Upper,
			tA:  blas.NoTrans,
			d:   blas.Unit,
			x:   []float64{5, 6, 7},
			ans: []float64{182, -99, 7},
		},
		{
			n: 3,
			a: [][]float64{
				{1, 0, 0},
				{2, 8, 0},
				{3, 15, 8},
			},
			ul:  blas.Lower,
			tA:  blas.NoTrans,
			d:   blas.NonUnit,
			x:   []float64{5, 6, 7},
			ans: []float64{5, -0.5, -0.0625},
		},
		{
			n: 3,
			a: [][]float64{
				{1, 0, 0},
				{2, 8, 0},
				{3, 15, 8},
			},
			ul:  blas.Lower,
			tA:  blas.NoTrans,
			d:   blas.Unit,
			x:   []float64{5, 6, 7},
			ans: []float64{5, -4, 52},
		},
		{
			n: 3,
			a: [][]float64{
				{1, 2, 3},
				{0, 8, 15},
				{0, 0, 8},
			},
			ul:  blas.Upper,
			tA:  blas.Trans,
			d:   blas.NonUnit,
			x:   []float64{5, 6, 7},
			ans: []float64{5, -0.5, -0.0625},
		},
		{
			n: 3,
			a: [][]float64{
				{1, 2, 3},
				{0, 8, 15},
				{0, 0, 8},
			},
			ul:  blas.Upper,
			tA:  blas.Trans,
			d:   blas.Unit,
			x:   []float64{5, 6, 7},
			ans: []float64{5, -4, 52},
		},
		{
			n: 3,
			a: [][]float64{
				{1, 0, 0},
				{2, 8, 0},
				{3, 15, 8},
			},
			ul:  blas.Lower,
			tA:  blas.Trans,
			d:   blas.NonUnit,
			x:   []float64{5, 6, 7},
			ans: []float64{4.15625, -0.890625, 0.875},
		},
		{
			n: 3,
			a: [][]float64{
				{1, 0, 0},
				{2, 1, 0},
				{3, 15, 1},
			},
			ul:  blas.Lower,
			tA:  blas.Trans,
			d:   blas.Unit,
			x:   []float64{5, 6, 7},
			ans: []float64{182, -99, 7},
		},
	} {
		// Increment/extra combinations applied to every case, in order.
		for _, layout := range []struct{ incX, extra int }{
			{1, 0},
			{-2, 0},
			{3, 0},
			{-3, 8},
			{4, 2},
		} {
			matrix := flatten(test.a)
			got := makeIncremented(test.x, layout.incX, layout.extra)
			blasser.Dtrsv(test.ul, test.tA, test.d, test.n, matrix, test.n, got, layout.incX)
			want := makeIncremented(test.ans, layout.incX, layout.extra)
			if floats.EqualApprox(got, want, 1e-14) {
				continue
			}
			t.Errorf("Case %v, incX = %v: Want %v, got %v.", i, layout.incX, want, got)
		}
	}
}

145
blas/testblas/dtxmv.go Normal file
View File

@@ -0,0 +1,145 @@
package testblas
import (
"testing"
"github.com/gonum/blas"
)
// Dtxmver groups the three triangular matrix-vector routines exercised
// together by DtxmvTest: Dtrmv, Dtbmv (which additionally takes k) and Dtpmv
// (which takes no leading dimension).
//
// DtxmvTest expects all three to produce the same result in x for the same
// case, each being given its own representation of the matrix.
type Dtxmver interface {
Dtrmv(ul blas.Uplo, tA blas.Transpose, d blas.Diag, n int, a []float64, lda int, x []float64, incX int)
Dtbmv(ul blas.Uplo, tA blas.Transpose, d blas.Diag, n, k int, a []float64, lda int, x []float64, incX int)
Dtpmv(ul blas.Uplo, tA blas.Transpose, d blas.Diag, n int, a []float64, x []float64, incX int)
}
// vec is an input vector for DtxmvTest: the backing slice together with the
// increment that is passed alongside it to the routine under test.
type vec struct {
// data is the backing storage, copied before each call so cases stay pristine.
data []float64
// inc is passed as incX; the table in this file uses 1, 2 and -2.
inc int
}
// cases is the table driving DtxmvTest. Each entry describes one triangular
// matrix in three representations, a set of equivalent input vectors with
// different increments, and the expected results for the untransposed and
// transposed products.
var cases = []struct {
// n is passed as the order of the matrix; k is passed only to Dtbmv.
n, k int
// ul and d are forwarded unchanged to every routine.
ul blas.Uplo
d blas.Diag
// ldab is the leading dimension passed to Dtbmv together with tb.
ldab int
// tr, tb and tp are the matrix arguments for Dtrmv (with lda == n),
// Dtbmv and Dtpmv respectively.
tr, tb, tp []float64
// ins lists the input vectors; each is tried against every routine.
ins []vec
// solNoTrans and solTrans are the expected results, compared with
// increment 1, for blas.NoTrans and blas.Trans.
solNoTrans []float64
solTrans []float64
}{
{
n: 3,
k: 1,
ul: blas.Upper,
d: blas.NonUnit,
tr: []float64{1, 2, 0, 0, 3, 4, 0, 0, 5},
tb: []float64{1, 2, 3, 4, 5, 0},
ldab: 2,
tp: []float64{1, 2, 0, 3, 4, 5},
ins: []vec{
{[]float64{2, 3, 4}, 1},
{[]float64{2, 1, 3, 1, 4}, 2},
{[]float64{4, 1, 3, 1, 2}, -2},
},
solNoTrans: []float64{8, 25, 20},
solTrans: []float64{2, 13, 32},
},
{
n: 3,
k: 1,
ul: blas.Upper,
d: blas.Unit,
tr: []float64{1, 2, 0, 0, 3, 4, 0, 0, 5},
tb: []float64{1, 2, 3, 4, 5, 0},
ldab: 2,
tp: []float64{1, 2, 0, 3, 4, 5},
ins: []vec{
{[]float64{2, 3, 4}, 1},
{[]float64{2, 1, 3, 1, 4}, 2},
{[]float64{4, 1, 3, 1, 2}, -2},
},
solNoTrans: []float64{8, 19, 4},
solTrans: []float64{2, 7, 16},
},
{
n: 3,
k: 1,
ul: blas.Lower,
d: blas.NonUnit,
tr: []float64{1, 0, 0, 2, 3, 0, 0, 4, 5},
tb: []float64{0, 1, 2, 3, 4, 5},
ldab: 2,
tp: []float64{1, 2, 3, 0, 4, 5},
ins: []vec{
{[]float64{2, 3, 4}, 1},
{[]float64{2, 1, 3, 1, 4}, 2},
{[]float64{4, 1, 3, 1, 2}, -2},
},
solNoTrans: []float64{2, 13, 32},
solTrans: []float64{8, 25, 20},
},
{
n: 3,
k: 1,
ul: blas.Lower,
d: blas.Unit,
tr: []float64{1, 0, 0, 2, 3, 0, 0, 4, 5},
tb: []float64{0, 1, 2, 3, 4, 5},
ldab: 2,
tp: []float64{1, 2, 3, 0, 4, 5},
ins: []vec{
{[]float64{2, 3, 4}, 1},
{[]float64{2, 1, 3, 1, 4}, 2},
{[]float64{4, 1, 3, 1, 2}, -2},
},
solNoTrans: []float64{2, 7, 16},
solTrans: []float64{8, 19, 4},
},
}
// DtxmvTest runs Dtrmv, Dtbmv and Dtpmv over every entry of cases and every
// input vector of that entry, first untransposed and then transposed.
//
// Each call works on a fresh copy of the input vector, and the result is
// compared with the expected solution through dStridedSliceTolEqual.
func DtxmvTest(t *testing.T, blasser Dtxmver) {
	for nc, c := range cases {
		for nx, x := range c.ins {
			// The steps are executed in the listed order; each receives its
			// own copy of the input so that no call sees another's output.
			for _, step := range []struct {
				msg  string
				want []float64
				call func(v []float64)
			}{
				{
					msg:  "Wrong Dtrmv result for: NoTrans in Case:",
					want: c.solNoTrans,
					call: func(v []float64) { blasser.Dtrmv(c.ul, blas.NoTrans, c.d, c.n, c.tr, c.n, v, x.inc) },
				},
				{
					msg:  "Wrong Dtrmv result for: Trans in Case:",
					want: c.solTrans,
					call: func(v []float64) { blasser.Dtrmv(c.ul, blas.Trans, c.d, c.n, c.tr, c.n, v, x.inc) },
				},
				{
					msg:  "Wrong Dtbmv result for: NoTrans in Case:",
					want: c.solNoTrans,
					call: func(v []float64) { blasser.Dtbmv(c.ul, blas.NoTrans, c.d, c.n, c.k, c.tb, c.ldab, v, x.inc) },
				},
				{
					msg:  "Wrong Dtbmv result for: Trans in Case:",
					want: c.solTrans,
					call: func(v []float64) { blasser.Dtbmv(c.ul, blas.Trans, c.d, c.n, c.k, c.tb, c.ldab, v, x.inc) },
				},
				{
					msg:  "Wrong Dtpmv result for: NoTrans in Case:",
					want: c.solNoTrans,
					call: func(v []float64) { blasser.Dtpmv(c.ul, blas.NoTrans, c.d, c.n, c.tp, v, x.inc) },
				},
				{
					msg:  "Wrong Dtpmv result for: Trans in Case:",
					want: c.solTrans,
					call: func(v []float64) { blasser.Dtpmv(c.ul, blas.Trans, c.d, c.n, c.tp, v, x.inc) },
				},
			} {
				v := make([]float64, len(x.data))
				copy(v, x.data)
				step.call(v)
				if !dStridedSliceTolEqual(c.n, v, x.inc, step.want, 1) {
					t.Error(step.msg, nc, "input:", nx)
				}
			}
		}
	}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,60 @@
package testblas
import (
"math/rand"
"testing"
"github.com/gonum/blas"
)
// DgemvBenchmark times repeated calls to blasser.Dgemv on an m*n-element
// matrix passed with leading dimension n, using the scalars 2 and 3.
//
// The vector lengths depend on tA: x has n entries and y has m when tA is
// blas.NoTrans, and the two are swapped otherwise. All inputs are filled with
// random values and laid out with makeIncremented before the timer is reset.
func DgemvBenchmark(b *testing.B, blasser Dgemver, tA blas.Transpose, m, n, incX, incY int) {
	lenX, lenY := m, n
	if tA == blas.NoTrans {
		lenX, lenY = n, m
	}

	// randomSlice returns size values drawn from the package-level source.
	randomSlice := func(size int) []float64 {
		s := make([]float64, size)
		for k := range s {
			s[k] = rand.Float64()
		}
		return s
	}

	x := makeIncremented(randomSlice(lenX), incX, 0)
	y := makeIncremented(randomSlice(lenY), incY, 0)
	a := randomSlice(m * n)

	b.ResetTimer()
	for iter := 0; iter < b.N; iter++ {
		blasser.Dgemv(tA, m, n, 2, a, n, x, incX, 3, y, incY)
	}
}
// DgerBenchmark times repeated calls to blasser.Dger with scalar 2 on an
// m*n-element matrix passed with leading dimension n.
//
// x is built from m random values and y from n, each laid out with
// makeIncremented for its increment; setup is excluded from the timing.
func DgerBenchmark(b *testing.B, blasser Dgerer, m, n, incX, incY int) {
	// randomSlice returns size values drawn from the package-level source.
	randomSlice := func(size int) []float64 {
		s := make([]float64, size)
		for k := range s {
			s[k] = rand.Float64()
		}
		return s
	}

	x := makeIncremented(randomSlice(m), incX, 0)
	y := makeIncremented(randomSlice(n), incY, 0)
	a := randomSlice(m * n)

	b.ResetTimer()
	for iter := 0; iter < b.N; iter++ {
		blasser.Dger(m, n, 2, x, incX, y, incY, a, n)
	}
}