mat: allow cross-raw shadow detection

This commit is contained in:
kortschak
2017-12-07 10:04:23 +10:30
committed by Dan Kortschak
parent b525d7913d
commit 835cce7bd0
4 changed files with 20 additions and 17 deletions

View File

@@ -164,8 +164,6 @@
// the value has been untransposed if necessary. // the value has been untransposed if necessary.
// //
// mat will not attempt to detect element overlap if the input does not implement a // mat will not attempt to detect element overlap if the input does not implement a
// Raw method, or if the Raw method differs from that of the receiver except when a // Raw method. Method behavior is undefined if there is undetected overlap.
// conversion has occurred through a mat API function. Method behavior is undefined
// if there is undetected overlap.
// //
package mat // import "gonum.org/v1/gonum/mat" package mat // import "gonum.org/v1/gonum/mat"

View File

@@ -70,8 +70,8 @@ func (m *Dense) checkOverlap(a blas64.General) bool {
return checkOverlap(m.RawMatrix(), a) return checkOverlap(m.RawMatrix(), a)
} }
func (s *SymDense) checkOverlap(a blas64.Symmetric) bool { func (s *SymDense) checkOverlap(a blas64.General) bool {
return checkOverlap(generalFromSymmetric(s.RawSymmetric()), generalFromSymmetric(a)) return checkOverlap(generalFromSymmetric(s.RawSymmetric()), a)
} }
// generalFromSymmetric returns a blas64.General with the backing // generalFromSymmetric returns a blas64.General with the backing
@@ -85,8 +85,8 @@ func generalFromSymmetric(a blas64.Symmetric) blas64.General {
} }
} }
func (t *TriDense) checkOverlap(a blas64.Triangular) bool { func (t *TriDense) checkOverlap(a blas64.General) bool {
return checkOverlap(generalFromTriangular(t.RawTriangular()), generalFromTriangular(a)) return checkOverlap(generalFromTriangular(t.RawTriangular()), a)
} }
// generalFromTriangular returns a blas64.General with the backing // generalFromTriangular returns a blas64.General with the backing

View File

@@ -179,10 +179,10 @@ func (s *SymDense) AddSym(a, b Symmetric) {
if b, ok := b.(RawSymmetricer); ok { if b, ok := b.(RawSymmetricer); ok {
amat, bmat := a.RawSymmetric(), b.RawSymmetric() amat, bmat := a.RawSymmetric(), b.RawSymmetric()
if s != a { if s != a {
s.checkOverlap(amat) s.checkOverlap(generalFromSymmetric(amat))
} }
if s != b { if s != b {
s.checkOverlap(bmat) s.checkOverlap(generalFromSymmetric(bmat))
} }
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
btmp := bmat.Data[i*bmat.Stride+i : i*bmat.Stride+n] btmp := bmat.Data[i*bmat.Stride+i : i*bmat.Stride+n]
@@ -240,7 +240,7 @@ func (s *SymDense) SymRankOne(a Symmetric, alpha float64, x *VecDense) {
s.reuseAs(n) s.reuseAs(n)
if s != a { if s != a {
if rs, ok := a.(RawSymmetricer); ok { if rs, ok := a.(RawSymmetricer); ok {
s.checkOverlap(rs.RawSymmetric()) s.checkOverlap(generalFromSymmetric(rs.RawSymmetric()))
} }
s.CopySym(a) s.CopySym(a)
} }
@@ -266,7 +266,7 @@ func (s *SymDense) SymRankK(a Symmetric, alpha float64, x Matrix) {
} }
if a != s { if a != s {
if rs, ok := a.(RawSymmetricer); ok { if rs, ok := a.(RawSymmetricer); ok {
s.checkOverlap(rs.RawSymmetric()) s.checkOverlap(generalFromSymmetric(rs.RawSymmetric()))
} }
s.reuseAs(n) s.reuseAs(n)
s.CopySym(a) s.CopySym(a)
@@ -304,8 +304,13 @@ func (s *SymDense) SymOuterK(alpha float64, x Matrix) {
s.CopySym(w) s.CopySym(w)
putWorkspaceSym(w) putWorkspaceSym(w)
} else { } else {
if rs, ok := x.(RawSymmetricer); ok { switch r := x.(type) {
s.checkOverlap(rs.RawSymmetric()) case RawMatrixer:
s.checkOverlap(r.RawMatrix())
case RawSymmetricer:
s.checkOverlap(generalFromSymmetric(r.RawSymmetric()))
case RawTriangular:
s.checkOverlap(generalFromTriangular(r.RawTriangular()))
} }
// Only zero the upper triangle. // Only zero the upper triangle.
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
@@ -337,7 +342,7 @@ func (s *SymDense) RankTwo(a Symmetric, alpha float64, x, y *VecDense) {
w.reuseAs(n) w.reuseAs(n)
if s != a { if s != a {
if rs, ok := a.(RawSymmetricer); ok { if rs, ok := a.(RawSymmetricer); ok {
s.checkOverlap(rs.RawSymmetric()) s.checkOverlap(generalFromSymmetric(rs.RawSymmetric()))
} }
w.CopySym(a) w.CopySym(a)
} }
@@ -352,7 +357,7 @@ func (s *SymDense) ScaleSym(f float64, a Symmetric) {
if a, ok := a.(RawSymmetricer); ok { if a, ok := a.(RawSymmetricer); ok {
amat := a.RawSymmetric() amat := a.RawSymmetric()
if s != a { if s != a {
s.checkOverlap(amat) s.checkOverlap(generalFromSymmetric(amat))
} }
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
for j := i; j < n; j++ { for j := i; j < n; j++ {
@@ -386,7 +391,7 @@ func (s *SymDense) SubsetSym(a Symmetric, set []int) {
if a, ok := a.(RawSymmetricer); ok { if a, ok := a.(RawSymmetricer); ok {
raw := a.RawSymmetric() raw := a.RawSymmetric()
if s != a { if s != a {
s.checkOverlap(raw) s.checkOverlap(generalFromSymmetric(raw))
} }
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
ssub := s.mat.Data[i*s.mat.Stride : i*s.mat.Stride+n] ssub := s.mat.Data[i*s.mat.Stride : i*s.mat.Stride+n]

View File

@@ -344,7 +344,7 @@ func (t *TriDense) Copy(a Matrix) (r, c int) {
// avoided where possible, for example by using the Solve routines. // avoided where possible, for example by using the Solve routines.
func (t *TriDense) InverseTri(a Triangular) error { func (t *TriDense) InverseTri(a Triangular) error {
if rt, ok := a.(RawTriangular); ok { if rt, ok := a.(RawTriangular); ok {
t.checkOverlap(rt.RawTriangular()) t.checkOverlap(generalFromTriangular(rt.RawTriangular()))
} }
n, _ := a.Triangle() n, _ := a.Triangle()
t.reuseAs(a.Triangle()) t.reuseAs(a.Triangle())