Mirror of https://github.com/gonum/gonum.git (synced 2025-11-03 11:21:14 +08:00)
blas/gonum: handle special values of beta consistently in Dsyrk
This commit is contained in:
Committed by: Vladimír Chalupecký
Parent: f41a0d0905
Commit: 73c94a2aff
@@ -463,17 +463,31 @@ func (Implementation) Dsyrk(ul blas.Uplo, tA blas.Transpose, n, k int, alpha flo
|
||||
for i := 0; i < n; i++ {
|
||||
ctmp := c[i*ldc+i : i*ldc+n]
|
||||
atmp := a[i*lda : i*lda+k]
|
||||
for jc, vc := range ctmp {
|
||||
j := jc + i
|
||||
ctmp[jc] = vc*beta + alpha*f64.DotUnitary(atmp, a[j*lda:j*lda+k])
|
||||
if beta == 0 {
|
||||
for jc := range ctmp {
|
||||
j := jc + i
|
||||
ctmp[jc] = alpha * f64.DotUnitary(atmp, a[j*lda:j*lda+k])
|
||||
}
|
||||
} else {
|
||||
for jc, vc := range ctmp {
|
||||
j := jc + i
|
||||
ctmp[jc] = vc*beta + alpha*f64.DotUnitary(atmp, a[j*lda:j*lda+k])
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
for i := 0; i < n; i++ {
|
||||
ctmp := c[i*ldc : i*ldc+i+1]
|
||||
atmp := a[i*lda : i*lda+k]
|
||||
for j, vc := range c[i*ldc : i*ldc+i+1] {
|
||||
c[i*ldc+j] = vc*beta + alpha*f64.DotUnitary(a[j*lda:j*lda+k], atmp)
|
||||
if beta == 0 {
|
||||
for j := range ctmp {
|
||||
ctmp[j] = alpha * f64.DotUnitary(a[j*lda:j*lda+k], atmp)
|
||||
}
|
||||
} else {
|
||||
for j, vc := range ctmp {
|
||||
ctmp[j] = vc*beta + alpha*f64.DotUnitary(a[j*lda:j*lda+k], atmp)
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
@@ -482,7 +496,11 @@ func (Implementation) Dsyrk(ul blas.Uplo, tA blas.Transpose, n, k int, alpha flo
|
||||
if ul == blas.Upper {
|
||||
for i := 0; i < n; i++ {
|
||||
ctmp := c[i*ldc+i : i*ldc+n]
|
||||
if beta != 1 {
|
||||
if beta == 0 {
|
||||
for j := range ctmp {
|
||||
ctmp[j] = 0
|
||||
}
|
||||
} else if beta != 1 {
|
||||
for j := range ctmp {
|
||||
ctmp[j] *= beta
|
||||
}
|
||||
|
||||
@@ -471,17 +471,31 @@ func (Implementation) Ssyrk(ul blas.Uplo, tA blas.Transpose, n, k int, alpha flo
|
||||
for i := 0; i < n; i++ {
|
||||
ctmp := c[i*ldc+i : i*ldc+n]
|
||||
atmp := a[i*lda : i*lda+k]
|
||||
for jc, vc := range ctmp {
|
||||
j := jc + i
|
||||
ctmp[jc] = vc*beta + alpha*f32.DotUnitary(atmp, a[j*lda:j*lda+k])
|
||||
if beta == 0 {
|
||||
for jc := range ctmp {
|
||||
j := jc + i
|
||||
ctmp[jc] = alpha * f32.DotUnitary(atmp, a[j*lda:j*lda+k])
|
||||
}
|
||||
} else {
|
||||
for jc, vc := range ctmp {
|
||||
j := jc + i
|
||||
ctmp[jc] = vc*beta + alpha*f32.DotUnitary(atmp, a[j*lda:j*lda+k])
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
for i := 0; i < n; i++ {
|
||||
ctmp := c[i*ldc : i*ldc+i+1]
|
||||
atmp := a[i*lda : i*lda+k]
|
||||
for j, vc := range c[i*ldc : i*ldc+i+1] {
|
||||
c[i*ldc+j] = vc*beta + alpha*f32.DotUnitary(a[j*lda:j*lda+k], atmp)
|
||||
if beta == 0 {
|
||||
for j := range ctmp {
|
||||
ctmp[j] = alpha * f32.DotUnitary(a[j*lda:j*lda+k], atmp)
|
||||
}
|
||||
} else {
|
||||
for j, vc := range ctmp {
|
||||
ctmp[j] = vc*beta + alpha*f32.DotUnitary(a[j*lda:j*lda+k], atmp)
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
@@ -490,7 +504,11 @@ func (Implementation) Ssyrk(ul blas.Uplo, tA blas.Transpose, n, k int, alpha flo
|
||||
if ul == blas.Upper {
|
||||
for i := 0; i < n; i++ {
|
||||
ctmp := c[i*ldc+i : i*ldc+n]
|
||||
if beta != 1 {
|
||||
if beta == 0 {
|
||||
for j := range ctmp {
|
||||
ctmp[j] = 0
|
||||
}
|
||||
} else if beta != 1 {
|
||||
for j := range ctmp {
|
||||
ctmp[j] *= beta
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user