mat: improve mat element shadowing detection

This commit is contained in:
Dan Kortschak
2018-05-21 13:50:56 +09:30
committed by Dan Kortschak
parent d684d065a3
commit 5f11fd92d7
5 changed files with 77 additions and 3 deletions

View File

@@ -485,6 +485,7 @@ func (m *Dense) Copy(a Matrix) (r, c int) {
// Nothing to do. // Nothing to do.
} }
default: default:
m.checkOverlapMatrix(aU)
for i := 0; i < r; i++ { for i := 0; i < r; i++ {
for j := 0; j < c; j++ { for j := 0; j < c; j++ {
m.set(i, j, a.At(i, j)) m.set(i, j, a.At(i, j))

View File

@@ -43,6 +43,8 @@ func (m *Dense) Add(a, b Matrix) {
} }
} }
m.checkOverlapMatrix(aU)
m.checkOverlapMatrix(bU)
var restore func() var restore func()
if m == aU { if m == aU {
m, restore = m.isolatedWorkspace(aU) m, restore = m.isolatedWorkspace(aU)
@@ -90,6 +92,8 @@ func (m *Dense) Sub(a, b Matrix) {
} }
} }
m.checkOverlapMatrix(aU)
m.checkOverlapMatrix(bU)
var restore func() var restore func()
if m == aU { if m == aU {
m, restore = m.isolatedWorkspace(aU) m, restore = m.isolatedWorkspace(aU)
@@ -138,6 +142,8 @@ func (m *Dense) MulElem(a, b Matrix) {
} }
} }
m.checkOverlapMatrix(aU)
m.checkOverlapMatrix(bU)
var restore func() var restore func()
if m == aU { if m == aU {
m, restore = m.isolatedWorkspace(aU) m, restore = m.isolatedWorkspace(aU)
@@ -186,6 +192,8 @@ func (m *Dense) DivElem(a, b Matrix) {
} }
} }
m.checkOverlapMatrix(aU)
m.checkOverlapMatrix(bU)
var restore func() var restore func()
if m == aU { if m == aU {
m, restore = m.isolatedWorkspace(aU) m, restore = m.isolatedWorkspace(aU)
@@ -429,6 +437,8 @@ func (m *Dense) Mul(a, b Matrix) {
} }
} }
m.checkOverlapMatrix(aU)
m.checkOverlapMatrix(bU)
row := getFloats(ac, false) row := getFloats(ac, false)
defer putFloats(row) defer putFloats(row)
for r := 0; r < ar; r++ { for r := 0; r < ar; r++ {
@@ -699,6 +709,7 @@ func (m *Dense) Scale(f float64, a Matrix) {
return return
} }
m.checkOverlapMatrix(a)
for r := 0; r < ar; r++ { for r := 0; r < ar; r++ {
for c := 0; c < ac; c++ { for c := 0; c < ac; c++ {
m.set(r, c, f*a.At(r, c)) m.set(r, c, f*a.At(r, c))
@@ -738,6 +749,7 @@ func (m *Dense) Apply(fn func(i, j int, v float64) float64, a Matrix) {
return return
} }
m.checkOverlapMatrix(a)
for r := 0; r < ar; r++ { for r := 0; r < ar; r++ {
for c := 0; c < ac; c++ { for c := 0; c < ac; c++ {
m.set(r, c, fn(r, c, a.At(r, c))) m.set(r, c, fn(r, c, a.At(r, c)))
@@ -845,6 +857,7 @@ func (m *Dense) Outer(alpha float64, x, y Vector) {
if rv, ok := xU.(RawVectorer); ok { if rv, ok := xU.(RawVectorer); ok {
xmat = rv.RawVector() xmat = rv.RawVector()
m.checkOverlap((&VecDense{mat: xmat, n: x.Len()}).asGeneral()) m.checkOverlap((&VecDense{mat: xmat, n: x.Len()}).asGeneral())
} else { } else {
fast = false fast = false
} }

View File

@@ -70,10 +70,46 @@ func (m *Dense) checkOverlap(a blas64.General) bool {
return checkOverlap(m.RawMatrix(), a) return checkOverlap(m.RawMatrix(), a)
} }
func (m *Dense) checkOverlapMatrix(a Matrix) bool {
if m == a {
return false
}
var amat blas64.General
switch a := a.(type) {
default:
return false
case RawMatrixer:
amat = a.RawMatrix()
case RawSymmetricer:
amat = generalFromSymmetric(a.RawSymmetric())
case RawTriangular:
amat = generalFromTriangular(a.RawTriangular())
}
return m.checkOverlap(amat)
}
func (s *SymDense) checkOverlap(a blas64.General) bool { func (s *SymDense) checkOverlap(a blas64.General) bool {
return checkOverlap(generalFromSymmetric(s.RawSymmetric()), a) return checkOverlap(generalFromSymmetric(s.RawSymmetric()), a)
} }
func (s *SymDense) checkOverlapMatrix(a Matrix) bool {
if s == a {
return false
}
var amat blas64.General
switch a := a.(type) {
default:
return false
case RawMatrixer:
amat = a.RawMatrix()
case RawSymmetricer:
amat = generalFromSymmetric(a.RawSymmetric())
case RawTriangular:
amat = generalFromTriangular(a.RawTriangular())
}
return s.checkOverlap(amat)
}
// generalFromSymmetric returns a blas64.General with the backing // generalFromSymmetric returns a blas64.General with the backing
// data and dimensions of a. // data and dimensions of a.
func generalFromSymmetric(a blas64.Symmetric) blas64.General { func generalFromSymmetric(a blas64.Symmetric) blas64.General {
@@ -89,6 +125,24 @@ func (t *TriDense) checkOverlap(a blas64.General) bool {
return checkOverlap(generalFromTriangular(t.RawTriangular()), a) return checkOverlap(generalFromTriangular(t.RawTriangular()), a)
} }
func (t *TriDense) checkOverlapMatrix(a Matrix) bool {
if t == a {
return false
}
var amat blas64.General
switch a := a.(type) {
default:
return false
case RawMatrixer:
amat = a.RawMatrix()
case RawSymmetricer:
amat = generalFromSymmetric(a.RawSymmetric())
case RawTriangular:
amat = generalFromTriangular(a.RawTriangular())
}
return t.checkOverlap(amat)
}
// generalFromTriangular returns a blas64.General with the backing // generalFromTriangular returns a blas64.General with the backing
// data and dimensions of a. // data and dimensions of a.
func generalFromTriangular(a blas64.Triangular) blas64.General { func generalFromTriangular(a blas64.Triangular) blas64.General {

View File

@@ -201,6 +201,8 @@ func (s *SymDense) AddSym(a, b Symmetric) {
} }
} }
s.checkOverlapMatrix(a)
s.checkOverlapMatrix(b)
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
stmp := s.mat.Data[i*s.mat.Stride : i*s.mat.Stride+n] stmp := s.mat.Data[i*s.mat.Stride : i*s.mat.Stride+n]
for j := i; j < n; j++ { for j := i; j < n; j++ {

View File

@@ -346,9 +346,7 @@ func (t *TriDense) Copy(a Matrix) (r, c int) {
// Note that matrix inversion is numerically unstable, and should generally be // Note that matrix inversion is numerically unstable, and should generally be
// avoided where possible, for example by using the Solve routines. // avoided where possible, for example by using the Solve routines.
func (t *TriDense) InverseTri(a Triangular) error { func (t *TriDense) InverseTri(a Triangular) error {
if rt, ok := a.(RawTriangular); ok { t.checkOverlapMatrix(a)
t.checkOverlap(generalFromTriangular(rt.RawTriangular()))
}
n, _ := a.Triangle() n, _ := a.Triangle()
t.reuseAs(a.Triangle()) t.reuseAs(a.Triangle())
t.Copy(a) t.Copy(a)
@@ -385,6 +383,8 @@ func (t *TriDense) MulTri(a, b Triangular) {
aU, _ := untransposeTri(a) aU, _ := untransposeTri(a)
bU, _ := untransposeTri(b) bU, _ := untransposeTri(b)
t.checkOverlapMatrix(bU)
t.checkOverlapMatrix(aU)
t.reuseAs(n, kind) t.reuseAs(n, kind)
var restore func() var restore func()
if t == aU { if t == aU {
@@ -430,6 +430,9 @@ func (t *TriDense) ScaleTri(f float64, a Triangular) {
switch a := a.(type) { switch a := a.(type) {
case RawTriangular: case RawTriangular:
amat := a.RawTriangular() amat := a.RawTriangular()
if t != a {
t.checkOverlap(generalFromTriangular(amat))
}
if kind == Upper { if kind == Upper {
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
ts := t.mat.Data[i*t.mat.Stride+i : i*t.mat.Stride+n] ts := t.mat.Data[i*t.mat.Stride+i : i*t.mat.Stride+n]
@@ -449,6 +452,7 @@ func (t *TriDense) ScaleTri(f float64, a Triangular) {
} }
return return
default: default:
t.checkOverlapMatrix(a)
isUpper := kind == Upper isUpper := kind == Upper
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
if isUpper { if isUpper {