diff --git a/mat/dense.go b/mat/dense.go index 4b060dc0..fbcac756 100644 --- a/mat/dense.go +++ b/mat/dense.go @@ -485,6 +485,7 @@ func (m *Dense) Copy(a Matrix) (r, c int) { // Nothing to do. } default: + m.checkOverlapMatrix(aU) for i := 0; i < r; i++ { for j := 0; j < c; j++ { m.set(i, j, a.At(i, j)) diff --git a/mat/dense_arithmetic.go b/mat/dense_arithmetic.go index 4773ea43..e909a6a1 100644 --- a/mat/dense_arithmetic.go +++ b/mat/dense_arithmetic.go @@ -43,6 +43,8 @@ func (m *Dense) Add(a, b Matrix) { } } + m.checkOverlapMatrix(aU) + m.checkOverlapMatrix(bU) var restore func() if m == aU { m, restore = m.isolatedWorkspace(aU) @@ -90,6 +92,8 @@ func (m *Dense) Sub(a, b Matrix) { } } + m.checkOverlapMatrix(aU) + m.checkOverlapMatrix(bU) var restore func() if m == aU { m, restore = m.isolatedWorkspace(aU) @@ -138,6 +142,8 @@ func (m *Dense) MulElem(a, b Matrix) { } } + m.checkOverlapMatrix(aU) + m.checkOverlapMatrix(bU) var restore func() if m == aU { m, restore = m.isolatedWorkspace(aU) @@ -186,6 +192,8 @@ func (m *Dense) DivElem(a, b Matrix) { } } + m.checkOverlapMatrix(aU) + m.checkOverlapMatrix(bU) var restore func() if m == aU { m, restore = m.isolatedWorkspace(aU) @@ -429,6 +437,8 @@ func (m *Dense) Mul(a, b Matrix) { } } + m.checkOverlapMatrix(aU) + m.checkOverlapMatrix(bU) row := getFloats(ac, false) defer putFloats(row) for r := 0; r < ar; r++ { @@ -699,6 +709,7 @@ func (m *Dense) Scale(f float64, a Matrix) { return } + m.checkOverlapMatrix(a) for r := 0; r < ar; r++ { for c := 0; c < ac; c++ { m.set(r, c, f*a.At(r, c)) @@ -738,6 +749,7 @@ func (m *Dense) Apply(fn func(i, j int, v float64) float64, a Matrix) { return } + m.checkOverlapMatrix(a) for r := 0; r < ar; r++ { for c := 0; c < ac; c++ { m.set(r, c, fn(r, c, a.At(r, c))) @@ -845,6 +857,7 @@ func (m *Dense) Outer(alpha float64, x, y Vector) { if rv, ok := xU.(RawVectorer); ok { xmat = rv.RawVector() m.checkOverlap((&VecDense{mat: xmat, n: x.Len()}).asGeneral()) + } else { fast = false } diff --git a/mat/shadow.go b/mat/shadow.go index bcc55d6a..cc62e44f 100644 --- a/mat/shadow.go +++ b/mat/shadow.go @@ -70,10 +70,46 @@ func (m *Dense) checkOverlap(a blas64.General) bool { return checkOverlap(m.RawMatrix(), a) } +func (m *Dense) checkOverlapMatrix(a Matrix) bool { + if m == a { + return false + } + var amat blas64.General + switch a := a.(type) { + default: + return false + case RawMatrixer: + amat = a.RawMatrix() + case RawSymmetricer: + amat = generalFromSymmetric(a.RawSymmetric()) + case RawTriangular: + amat = generalFromTriangular(a.RawTriangular()) + } + return m.checkOverlap(amat) +} + func (s *SymDense) checkOverlap(a blas64.General) bool { return checkOverlap(generalFromSymmetric(s.RawSymmetric()), a) } +func (s *SymDense) checkOverlapMatrix(a Matrix) bool { + if s == a { + return false + } + var amat blas64.General + switch a := a.(type) { + default: + return false + case RawMatrixer: + amat = a.RawMatrix() + case RawSymmetricer: + amat = generalFromSymmetric(a.RawSymmetric()) + case RawTriangular: + amat = generalFromTriangular(a.RawTriangular()) + } + return s.checkOverlap(amat) +} + // generalFromSymmetric returns a blas64.General with the backing // data and dimensions of a. func generalFromSymmetric(a blas64.Symmetric) blas64.General { @@ -89,6 +125,24 @@ func (t *TriDense) checkOverlap(a blas64.General) bool { return checkOverlap(generalFromTriangular(t.RawTriangular()), a) } +func (t *TriDense) checkOverlapMatrix(a Matrix) bool { + if t == a { + return false + } + var amat blas64.General + switch a := a.(type) { + default: + return false + case RawMatrixer: + amat = a.RawMatrix() + case RawSymmetricer: + amat = generalFromSymmetric(a.RawSymmetric()) + case RawTriangular: + amat = generalFromTriangular(a.RawTriangular()) + } + return t.checkOverlap(amat) +} + // generalFromTriangular returns a blas64.General with the backing // data and dimensions of a. func generalFromTriangular(a blas64.Triangular) blas64.General { diff --git a/mat/symmetric.go b/mat/symmetric.go index 6d70bf77..d3ac2617 100644 --- a/mat/symmetric.go +++ b/mat/symmetric.go @@ -201,6 +201,8 @@ func (s *SymDense) AddSym(a, b Symmetric) { } } + s.checkOverlapMatrix(a) + s.checkOverlapMatrix(b) for i := 0; i < n; i++ { stmp := s.mat.Data[i*s.mat.Stride : i*s.mat.Stride+n] for j := i; j < n; j++ { diff --git a/mat/triangular.go b/mat/triangular.go index 8a6b0ad5..c4446b55 100644 --- a/mat/triangular.go +++ b/mat/triangular.go @@ -346,9 +346,7 @@ func (t *TriDense) Copy(a Matrix) (r, c int) { // Note that matrix inversion is numerically unstable, and should generally be // avoided where possible, for example by using the Solve routines. func (t *TriDense) InverseTri(a Triangular) error { - if rt, ok := a.(RawTriangular); ok { - t.checkOverlap(generalFromTriangular(rt.RawTriangular())) - } + t.checkOverlapMatrix(a) n, _ := a.Triangle() t.reuseAs(a.Triangle()) t.Copy(a) @@ -385,6 +383,8 @@ func (t *TriDense) MulTri(a, b Triangular) { aU, _ := untransposeTri(a) bU, _ := untransposeTri(b) + t.checkOverlapMatrix(bU) + t.checkOverlapMatrix(aU) t.reuseAs(n, kind) var restore func() if t == aU { @@ -430,6 +430,9 @@ func (t *TriDense) ScaleTri(f float64, a Triangular) { switch a := a.(type) { case RawTriangular: amat := a.RawTriangular() + if t != a { + t.checkOverlap(generalFromTriangular(amat)) + } if kind == Upper { for i := 0; i < n; i++ { ts := t.mat.Data[i*t.mat.Stride+i : i*t.mat.Stride+n] @@ -449,6 +452,7 @@ func (t *TriDense) ScaleTri(f float64, a Triangular) { } return default: + t.checkOverlapMatrix(a) isUpper := kind == Upper for i := 0; i < n; i++ { if isUpper {