mirror of
https://github.com/gonum/gonum.git
synced 2025-10-19 13:35:51 +08:00
blas/native: run go generate
This commit is contained in:
@@ -7,77 +7,15 @@
|
|||||||
package native
|
package native
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
math "gonum.org/v1/gonum/blas/native/internal/math32"
|
math "gonum.org/v1/gonum/blas/native/internal/math32"
|
||||||
)
|
)
|
||||||
|
|
||||||
func newGeneral32(r, c int) general32 {
|
|
||||||
return general32{
|
|
||||||
data: make([]float32, r*c),
|
|
||||||
rows: r,
|
|
||||||
cols: c,
|
|
||||||
stride: c,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type general32 struct {
|
type general32 struct {
|
||||||
data []float32
|
data []float32
|
||||||
rows, cols int
|
rows, cols int
|
||||||
stride int
|
stride int
|
||||||
}
|
}
|
||||||
|
|
||||||
// adds element-wise into receiver. rows and columns must match
|
|
||||||
func (g general32) add(h general32) {
|
|
||||||
if debug {
|
|
||||||
if g.rows != h.rows {
|
|
||||||
panic("blas: row size mismatch")
|
|
||||||
}
|
|
||||||
if g.cols != h.cols {
|
|
||||||
panic("blas: col size mismatch")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for i := 0; i < g.rows; i++ {
|
|
||||||
gtmp := g.data[i*g.stride : i*g.stride+g.cols]
|
|
||||||
for j, v := range h.data[i*h.stride : i*h.stride+h.cols] {
|
|
||||||
gtmp[j] += v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// at returns the value at the ith row and jth column. For speed reasons, the
|
|
||||||
// rows and columns are not bounds checked.
|
|
||||||
func (g general32) at(i, j int) float32 {
|
|
||||||
if debug {
|
|
||||||
if i < 0 || i >= g.rows {
|
|
||||||
panic("blas: row out of bounds")
|
|
||||||
}
|
|
||||||
if j < 0 || j >= g.cols {
|
|
||||||
panic("blas: col out of bounds")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return g.data[i*g.stride+j]
|
|
||||||
}
|
|
||||||
|
|
||||||
func (g general32) check(c byte) error {
|
|
||||||
if g.rows < 0 {
|
|
||||||
return errors.New("blas: rows < 0")
|
|
||||||
}
|
|
||||||
if g.cols < 0 {
|
|
||||||
return errors.New("blas: cols < 0")
|
|
||||||
}
|
|
||||||
if g.stride < 1 {
|
|
||||||
return errors.New("blas: stride < 1")
|
|
||||||
}
|
|
||||||
if g.stride < g.cols {
|
|
||||||
return errors.New("blas: illegal stride")
|
|
||||||
}
|
|
||||||
if (g.rows-1)*g.stride+g.cols > len(g.data) {
|
|
||||||
return fmt.Errorf("blas: index of %c out of range", c)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (g general32) clone() general32 {
|
func (g general32) clone() general32 {
|
||||||
data := make([]float32, len(g.data))
|
data := make([]float32, len(g.data))
|
||||||
copy(data, g.data)
|
copy(data, g.data)
|
||||||
@@ -89,21 +27,6 @@ func (g general32) clone() general32 {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// assumes they are the same size
|
|
||||||
func (g general32) copy(h general32) {
|
|
||||||
if debug {
|
|
||||||
if g.rows != h.rows {
|
|
||||||
panic("blas: row mismatch")
|
|
||||||
}
|
|
||||||
if g.cols != h.cols {
|
|
||||||
panic("blas: col mismatch")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for k := 0; k < g.rows; k++ {
|
|
||||||
copy(g.data[k*g.stride:(k+1)*g.stride], h.data[k*h.stride:(k+1)*h.stride])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (g general32) equal(a general32) bool {
|
func (g general32) equal(a general32) bool {
|
||||||
if g.rows != a.rows || g.cols != a.cols || g.stride != a.stride {
|
if g.rows != a.rows || g.cols != a.cols || g.stride != a.stride {
|
||||||
return false
|
return false
|
||||||
@@ -116,34 +39,6 @@ func (g general32) equal(a general32) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
|
||||||
// print is to aid debugging. Commented out to avoid fmt import
|
|
||||||
func (g general32) print() {
|
|
||||||
fmt.Println("r = ", g.rows, "c = ", g.cols, "stride: ", g.stride)
|
|
||||||
for i := 0; i < g.rows; i++ {
|
|
||||||
fmt.Println(g.data[i*g.stride : (i+1)*g.stride])
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
|
|
||||||
func (g general32) view(i, j, r, c int) general32 {
|
|
||||||
if debug {
|
|
||||||
if i < 0 || i+r > g.rows {
|
|
||||||
panic("blas: row out of bounds")
|
|
||||||
}
|
|
||||||
if j < 0 || j+c > g.cols {
|
|
||||||
panic("blas: col out of bounds")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return general32{
|
|
||||||
data: g.data[i*g.stride+j : (i+r-1)*g.stride+j+c],
|
|
||||||
rows: r,
|
|
||||||
cols: c,
|
|
||||||
stride: g.stride,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (g general32) equalWithinAbs(a general32, tol float32) bool {
|
func (g general32) equalWithinAbs(a general32, tol float32) bool {
|
||||||
if g.rows != a.rows || g.cols != a.cols || g.stride != a.stride {
|
if g.rows != a.rows || g.cols != a.cols || g.stride != a.stride {
|
||||||
return false
|
return false
|
||||||
|
@@ -587,7 +587,6 @@ func (Implementation) Srotm(n int, x []float32, incX int, y []float32, incY int,
|
|||||||
ix += incX
|
ix += incX
|
||||||
iy += incY
|
iy += incY
|
||||||
}
|
}
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sscal scales x by alpha.
|
// Sscal scales x by alpha.
|
||||||
|
@@ -1552,7 +1552,6 @@ func (Implementation) Ssbmv(ul blas.Uplo, n, k int, alpha float32, a []float32,
|
|||||||
ix += incX
|
ix += incX
|
||||||
iy += incY
|
iy += incY
|
||||||
}
|
}
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ssyr performs the rank-one update
|
// Ssyr performs the rank-one update
|
||||||
@@ -1744,7 +1743,6 @@ func (Implementation) Ssyr2(ul blas.Uplo, n int, alpha float32, x []float32, inc
|
|||||||
ix += incX
|
ix += incX
|
||||||
iy += incY
|
iy += incY
|
||||||
}
|
}
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stpsv solves
|
// Stpsv solves
|
||||||
|
Reference in New Issue
Block a user