mirror of
https://github.com/gonum/gonum.git
synced 2025-10-09 00:50:16 +08:00
mat: improve mat element shadowing detection
This commit is contained in:

committed by
Dan Kortschak

parent
d684d065a3
commit
5f11fd92d7
@@ -485,6 +485,7 @@ func (m *Dense) Copy(a Matrix) (r, c int) {
|
||||
// Nothing to do.
|
||||
}
|
||||
default:
|
||||
m.checkOverlapMatrix(aU)
|
||||
for i := 0; i < r; i++ {
|
||||
for j := 0; j < c; j++ {
|
||||
m.set(i, j, a.At(i, j))
|
||||
|
@@ -43,6 +43,8 @@ func (m *Dense) Add(a, b Matrix) {
|
||||
}
|
||||
}
|
||||
|
||||
m.checkOverlapMatrix(aU)
|
||||
m.checkOverlapMatrix(bU)
|
||||
var restore func()
|
||||
if m == aU {
|
||||
m, restore = m.isolatedWorkspace(aU)
|
||||
@@ -90,6 +92,8 @@ func (m *Dense) Sub(a, b Matrix) {
|
||||
}
|
||||
}
|
||||
|
||||
m.checkOverlapMatrix(aU)
|
||||
m.checkOverlapMatrix(bU)
|
||||
var restore func()
|
||||
if m == aU {
|
||||
m, restore = m.isolatedWorkspace(aU)
|
||||
@@ -138,6 +142,8 @@ func (m *Dense) MulElem(a, b Matrix) {
|
||||
}
|
||||
}
|
||||
|
||||
m.checkOverlapMatrix(aU)
|
||||
m.checkOverlapMatrix(bU)
|
||||
var restore func()
|
||||
if m == aU {
|
||||
m, restore = m.isolatedWorkspace(aU)
|
||||
@@ -186,6 +192,8 @@ func (m *Dense) DivElem(a, b Matrix) {
|
||||
}
|
||||
}
|
||||
|
||||
m.checkOverlapMatrix(aU)
|
||||
m.checkOverlapMatrix(bU)
|
||||
var restore func()
|
||||
if m == aU {
|
||||
m, restore = m.isolatedWorkspace(aU)
|
||||
@@ -429,6 +437,8 @@ func (m *Dense) Mul(a, b Matrix) {
|
||||
}
|
||||
}
|
||||
|
||||
m.checkOverlapMatrix(aU)
|
||||
m.checkOverlapMatrix(bU)
|
||||
row := getFloats(ac, false)
|
||||
defer putFloats(row)
|
||||
for r := 0; r < ar; r++ {
|
||||
@@ -699,6 +709,7 @@ func (m *Dense) Scale(f float64, a Matrix) {
|
||||
return
|
||||
}
|
||||
|
||||
m.checkOverlapMatrix(a)
|
||||
for r := 0; r < ar; r++ {
|
||||
for c := 0; c < ac; c++ {
|
||||
m.set(r, c, f*a.At(r, c))
|
||||
@@ -738,6 +749,7 @@ func (m *Dense) Apply(fn func(i, j int, v float64) float64, a Matrix) {
|
||||
return
|
||||
}
|
||||
|
||||
m.checkOverlapMatrix(a)
|
||||
for r := 0; r < ar; r++ {
|
||||
for c := 0; c < ac; c++ {
|
||||
m.set(r, c, fn(r, c, a.At(r, c)))
|
||||
@@ -845,6 +857,7 @@ func (m *Dense) Outer(alpha float64, x, y Vector) {
|
||||
if rv, ok := xU.(RawVectorer); ok {
|
||||
xmat = rv.RawVector()
|
||||
m.checkOverlap((&VecDense{mat: xmat, n: x.Len()}).asGeneral())
|
||||
|
||||
} else {
|
||||
fast = false
|
||||
}
|
||||
|
@@ -70,10 +70,46 @@ func (m *Dense) checkOverlap(a blas64.General) bool {
|
||||
return checkOverlap(m.RawMatrix(), a)
|
||||
}
|
||||
|
||||
func (m *Dense) checkOverlapMatrix(a Matrix) bool {
|
||||
if m == a {
|
||||
return false
|
||||
}
|
||||
var amat blas64.General
|
||||
switch a := a.(type) {
|
||||
default:
|
||||
return false
|
||||
case RawMatrixer:
|
||||
amat = a.RawMatrix()
|
||||
case RawSymmetricer:
|
||||
amat = generalFromSymmetric(a.RawSymmetric())
|
||||
case RawTriangular:
|
||||
amat = generalFromTriangular(a.RawTriangular())
|
||||
}
|
||||
return m.checkOverlap(amat)
|
||||
}
|
||||
|
||||
func (s *SymDense) checkOverlap(a blas64.General) bool {
|
||||
return checkOverlap(generalFromSymmetric(s.RawSymmetric()), a)
|
||||
}
|
||||
|
||||
func (s *SymDense) checkOverlapMatrix(a Matrix) bool {
|
||||
if s == a {
|
||||
return false
|
||||
}
|
||||
var amat blas64.General
|
||||
switch a := a.(type) {
|
||||
default:
|
||||
return false
|
||||
case RawMatrixer:
|
||||
amat = a.RawMatrix()
|
||||
case RawSymmetricer:
|
||||
amat = generalFromSymmetric(a.RawSymmetric())
|
||||
case RawTriangular:
|
||||
amat = generalFromTriangular(a.RawTriangular())
|
||||
}
|
||||
return s.checkOverlap(amat)
|
||||
}
|
||||
|
||||
// generalFromSymmetric returns a blas64.General with the backing
|
||||
// data and dimensions of a.
|
||||
func generalFromSymmetric(a blas64.Symmetric) blas64.General {
|
||||
@@ -89,6 +125,24 @@ func (t *TriDense) checkOverlap(a blas64.General) bool {
|
||||
return checkOverlap(generalFromTriangular(t.RawTriangular()), a)
|
||||
}
|
||||
|
||||
func (t *TriDense) checkOverlapMatrix(a Matrix) bool {
|
||||
if t == a {
|
||||
return false
|
||||
}
|
||||
var amat blas64.General
|
||||
switch a := a.(type) {
|
||||
default:
|
||||
return false
|
||||
case RawMatrixer:
|
||||
amat = a.RawMatrix()
|
||||
case RawSymmetricer:
|
||||
amat = generalFromSymmetric(a.RawSymmetric())
|
||||
case RawTriangular:
|
||||
amat = generalFromTriangular(a.RawTriangular())
|
||||
}
|
||||
return t.checkOverlap(amat)
|
||||
}
|
||||
|
||||
// generalFromTriangular returns a blas64.General with the backing
|
||||
// data and dimensions of a.
|
||||
func generalFromTriangular(a blas64.Triangular) blas64.General {
|
||||
|
@@ -201,6 +201,8 @@ func (s *SymDense) AddSym(a, b Symmetric) {
|
||||
}
|
||||
}
|
||||
|
||||
s.checkOverlapMatrix(a)
|
||||
s.checkOverlapMatrix(b)
|
||||
for i := 0; i < n; i++ {
|
||||
stmp := s.mat.Data[i*s.mat.Stride : i*s.mat.Stride+n]
|
||||
for j := i; j < n; j++ {
|
||||
|
@@ -346,9 +346,7 @@ func (t *TriDense) Copy(a Matrix) (r, c int) {
|
||||
// Note that matrix inversion is numerically unstable, and should generally be
|
||||
// avoided where possible, for example by using the Solve routines.
|
||||
func (t *TriDense) InverseTri(a Triangular) error {
|
||||
if rt, ok := a.(RawTriangular); ok {
|
||||
t.checkOverlap(generalFromTriangular(rt.RawTriangular()))
|
||||
}
|
||||
t.checkOverlapMatrix(a)
|
||||
n, _ := a.Triangle()
|
||||
t.reuseAs(a.Triangle())
|
||||
t.Copy(a)
|
||||
@@ -385,6 +383,8 @@ func (t *TriDense) MulTri(a, b Triangular) {
|
||||
|
||||
aU, _ := untransposeTri(a)
|
||||
bU, _ := untransposeTri(b)
|
||||
t.checkOverlapMatrix(bU)
|
||||
t.checkOverlapMatrix(aU)
|
||||
t.reuseAs(n, kind)
|
||||
var restore func()
|
||||
if t == aU {
|
||||
@@ -430,6 +430,9 @@ func (t *TriDense) ScaleTri(f float64, a Triangular) {
|
||||
switch a := a.(type) {
|
||||
case RawTriangular:
|
||||
amat := a.RawTriangular()
|
||||
if t != a {
|
||||
t.checkOverlap(generalFromTriangular(amat))
|
||||
}
|
||||
if kind == Upper {
|
||||
for i := 0; i < n; i++ {
|
||||
ts := t.mat.Data[i*t.mat.Stride+i : i*t.mat.Stride+n]
|
||||
@@ -449,6 +452,7 @@ func (t *TriDense) ScaleTri(f float64, a Triangular) {
|
||||
}
|
||||
return
|
||||
default:
|
||||
t.checkOverlapMatrix(a)
|
||||
isUpper := kind == Upper
|
||||
for i := 0; i < n; i++ {
|
||||
if isUpper {
|
||||
|
Reference in New Issue
Block a user