mirror of
https://github.com/gonum/gonum.git
synced 2025-10-09 00:50:16 +08:00
mat: improve mat element shadowing detection
This commit is contained in:

committed by
Dan Kortschak

parent
d684d065a3
commit
5f11fd92d7
@@ -485,6 +485,7 @@ func (m *Dense) Copy(a Matrix) (r, c int) {
|
|||||||
// Nothing to do.
|
// Nothing to do.
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
|
m.checkOverlapMatrix(aU)
|
||||||
for i := 0; i < r; i++ {
|
for i := 0; i < r; i++ {
|
||||||
for j := 0; j < c; j++ {
|
for j := 0; j < c; j++ {
|
||||||
m.set(i, j, a.At(i, j))
|
m.set(i, j, a.At(i, j))
|
||||||
|
@@ -43,6 +43,8 @@ func (m *Dense) Add(a, b Matrix) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
m.checkOverlapMatrix(aU)
|
||||||
|
m.checkOverlapMatrix(bU)
|
||||||
var restore func()
|
var restore func()
|
||||||
if m == aU {
|
if m == aU {
|
||||||
m, restore = m.isolatedWorkspace(aU)
|
m, restore = m.isolatedWorkspace(aU)
|
||||||
@@ -90,6 +92,8 @@ func (m *Dense) Sub(a, b Matrix) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
m.checkOverlapMatrix(aU)
|
||||||
|
m.checkOverlapMatrix(bU)
|
||||||
var restore func()
|
var restore func()
|
||||||
if m == aU {
|
if m == aU {
|
||||||
m, restore = m.isolatedWorkspace(aU)
|
m, restore = m.isolatedWorkspace(aU)
|
||||||
@@ -138,6 +142,8 @@ func (m *Dense) MulElem(a, b Matrix) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
m.checkOverlapMatrix(aU)
|
||||||
|
m.checkOverlapMatrix(bU)
|
||||||
var restore func()
|
var restore func()
|
||||||
if m == aU {
|
if m == aU {
|
||||||
m, restore = m.isolatedWorkspace(aU)
|
m, restore = m.isolatedWorkspace(aU)
|
||||||
@@ -186,6 +192,8 @@ func (m *Dense) DivElem(a, b Matrix) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
m.checkOverlapMatrix(aU)
|
||||||
|
m.checkOverlapMatrix(bU)
|
||||||
var restore func()
|
var restore func()
|
||||||
if m == aU {
|
if m == aU {
|
||||||
m, restore = m.isolatedWorkspace(aU)
|
m, restore = m.isolatedWorkspace(aU)
|
||||||
@@ -429,6 +437,8 @@ func (m *Dense) Mul(a, b Matrix) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
m.checkOverlapMatrix(aU)
|
||||||
|
m.checkOverlapMatrix(bU)
|
||||||
row := getFloats(ac, false)
|
row := getFloats(ac, false)
|
||||||
defer putFloats(row)
|
defer putFloats(row)
|
||||||
for r := 0; r < ar; r++ {
|
for r := 0; r < ar; r++ {
|
||||||
@@ -699,6 +709,7 @@ func (m *Dense) Scale(f float64, a Matrix) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
m.checkOverlapMatrix(a)
|
||||||
for r := 0; r < ar; r++ {
|
for r := 0; r < ar; r++ {
|
||||||
for c := 0; c < ac; c++ {
|
for c := 0; c < ac; c++ {
|
||||||
m.set(r, c, f*a.At(r, c))
|
m.set(r, c, f*a.At(r, c))
|
||||||
@@ -738,6 +749,7 @@ func (m *Dense) Apply(fn func(i, j int, v float64) float64, a Matrix) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
m.checkOverlapMatrix(a)
|
||||||
for r := 0; r < ar; r++ {
|
for r := 0; r < ar; r++ {
|
||||||
for c := 0; c < ac; c++ {
|
for c := 0; c < ac; c++ {
|
||||||
m.set(r, c, fn(r, c, a.At(r, c)))
|
m.set(r, c, fn(r, c, a.At(r, c)))
|
||||||
@@ -845,6 +857,7 @@ func (m *Dense) Outer(alpha float64, x, y Vector) {
|
|||||||
if rv, ok := xU.(RawVectorer); ok {
|
if rv, ok := xU.(RawVectorer); ok {
|
||||||
xmat = rv.RawVector()
|
xmat = rv.RawVector()
|
||||||
m.checkOverlap((&VecDense{mat: xmat, n: x.Len()}).asGeneral())
|
m.checkOverlap((&VecDense{mat: xmat, n: x.Len()}).asGeneral())
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
fast = false
|
fast = false
|
||||||
}
|
}
|
||||||
|
@@ -70,10 +70,46 @@ func (m *Dense) checkOverlap(a blas64.General) bool {
|
|||||||
return checkOverlap(m.RawMatrix(), a)
|
return checkOverlap(m.RawMatrix(), a)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Dense) checkOverlapMatrix(a Matrix) bool {
|
||||||
|
if m == a {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
var amat blas64.General
|
||||||
|
switch a := a.(type) {
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
case RawMatrixer:
|
||||||
|
amat = a.RawMatrix()
|
||||||
|
case RawSymmetricer:
|
||||||
|
amat = generalFromSymmetric(a.RawSymmetric())
|
||||||
|
case RawTriangular:
|
||||||
|
amat = generalFromTriangular(a.RawTriangular())
|
||||||
|
}
|
||||||
|
return m.checkOverlap(amat)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *SymDense) checkOverlap(a blas64.General) bool {
|
func (s *SymDense) checkOverlap(a blas64.General) bool {
|
||||||
return checkOverlap(generalFromSymmetric(s.RawSymmetric()), a)
|
return checkOverlap(generalFromSymmetric(s.RawSymmetric()), a)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *SymDense) checkOverlapMatrix(a Matrix) bool {
|
||||||
|
if s == a {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
var amat blas64.General
|
||||||
|
switch a := a.(type) {
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
case RawMatrixer:
|
||||||
|
amat = a.RawMatrix()
|
||||||
|
case RawSymmetricer:
|
||||||
|
amat = generalFromSymmetric(a.RawSymmetric())
|
||||||
|
case RawTriangular:
|
||||||
|
amat = generalFromTriangular(a.RawTriangular())
|
||||||
|
}
|
||||||
|
return s.checkOverlap(amat)
|
||||||
|
}
|
||||||
|
|
||||||
// generalFromSymmetric returns a blas64.General with the backing
|
// generalFromSymmetric returns a blas64.General with the backing
|
||||||
// data and dimensions of a.
|
// data and dimensions of a.
|
||||||
func generalFromSymmetric(a blas64.Symmetric) blas64.General {
|
func generalFromSymmetric(a blas64.Symmetric) blas64.General {
|
||||||
@@ -89,6 +125,24 @@ func (t *TriDense) checkOverlap(a blas64.General) bool {
|
|||||||
return checkOverlap(generalFromTriangular(t.RawTriangular()), a)
|
return checkOverlap(generalFromTriangular(t.RawTriangular()), a)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *TriDense) checkOverlapMatrix(a Matrix) bool {
|
||||||
|
if t == a {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
var amat blas64.General
|
||||||
|
switch a := a.(type) {
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
case RawMatrixer:
|
||||||
|
amat = a.RawMatrix()
|
||||||
|
case RawSymmetricer:
|
||||||
|
amat = generalFromSymmetric(a.RawSymmetric())
|
||||||
|
case RawTriangular:
|
||||||
|
amat = generalFromTriangular(a.RawTriangular())
|
||||||
|
}
|
||||||
|
return t.checkOverlap(amat)
|
||||||
|
}
|
||||||
|
|
||||||
// generalFromTriangular returns a blas64.General with the backing
|
// generalFromTriangular returns a blas64.General with the backing
|
||||||
// data and dimensions of a.
|
// data and dimensions of a.
|
||||||
func generalFromTriangular(a blas64.Triangular) blas64.General {
|
func generalFromTriangular(a blas64.Triangular) blas64.General {
|
||||||
|
@@ -201,6 +201,8 @@ func (s *SymDense) AddSym(a, b Symmetric) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.checkOverlapMatrix(a)
|
||||||
|
s.checkOverlapMatrix(b)
|
||||||
for i := 0; i < n; i++ {
|
for i := 0; i < n; i++ {
|
||||||
stmp := s.mat.Data[i*s.mat.Stride : i*s.mat.Stride+n]
|
stmp := s.mat.Data[i*s.mat.Stride : i*s.mat.Stride+n]
|
||||||
for j := i; j < n; j++ {
|
for j := i; j < n; j++ {
|
||||||
|
@@ -346,9 +346,7 @@ func (t *TriDense) Copy(a Matrix) (r, c int) {
|
|||||||
// Note that matrix inversion is numerically unstable, and should generally be
|
// Note that matrix inversion is numerically unstable, and should generally be
|
||||||
// avoided where possible, for example by using the Solve routines.
|
// avoided where possible, for example by using the Solve routines.
|
||||||
func (t *TriDense) InverseTri(a Triangular) error {
|
func (t *TriDense) InverseTri(a Triangular) error {
|
||||||
if rt, ok := a.(RawTriangular); ok {
|
t.checkOverlapMatrix(a)
|
||||||
t.checkOverlap(generalFromTriangular(rt.RawTriangular()))
|
|
||||||
}
|
|
||||||
n, _ := a.Triangle()
|
n, _ := a.Triangle()
|
||||||
t.reuseAs(a.Triangle())
|
t.reuseAs(a.Triangle())
|
||||||
t.Copy(a)
|
t.Copy(a)
|
||||||
@@ -385,6 +383,8 @@ func (t *TriDense) MulTri(a, b Triangular) {
|
|||||||
|
|
||||||
aU, _ := untransposeTri(a)
|
aU, _ := untransposeTri(a)
|
||||||
bU, _ := untransposeTri(b)
|
bU, _ := untransposeTri(b)
|
||||||
|
t.checkOverlapMatrix(bU)
|
||||||
|
t.checkOverlapMatrix(aU)
|
||||||
t.reuseAs(n, kind)
|
t.reuseAs(n, kind)
|
||||||
var restore func()
|
var restore func()
|
||||||
if t == aU {
|
if t == aU {
|
||||||
@@ -430,6 +430,9 @@ func (t *TriDense) ScaleTri(f float64, a Triangular) {
|
|||||||
switch a := a.(type) {
|
switch a := a.(type) {
|
||||||
case RawTriangular:
|
case RawTriangular:
|
||||||
amat := a.RawTriangular()
|
amat := a.RawTriangular()
|
||||||
|
if t != a {
|
||||||
|
t.checkOverlap(generalFromTriangular(amat))
|
||||||
|
}
|
||||||
if kind == Upper {
|
if kind == Upper {
|
||||||
for i := 0; i < n; i++ {
|
for i := 0; i < n; i++ {
|
||||||
ts := t.mat.Data[i*t.mat.Stride+i : i*t.mat.Stride+n]
|
ts := t.mat.Data[i*t.mat.Stride+i : i*t.mat.Stride+n]
|
||||||
@@ -449,6 +452,7 @@ func (t *TriDense) ScaleTri(f float64, a Triangular) {
|
|||||||
}
|
}
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
|
t.checkOverlapMatrix(a)
|
||||||
isUpper := kind == Upper
|
isUpper := kind == Upper
|
||||||
for i := 0; i < n; i++ {
|
for i := 0; i < n; i++ {
|
||||||
if isUpper {
|
if isUpper {
|
||||||
|
Reference in New Issue
Block a user