diff --git a/mat/doc.go b/mat/doc.go index 129444f9..2cc91001 100644 --- a/mat/doc.go +++ b/mat/doc.go @@ -164,8 +164,6 @@ // the value has been untransposed if necessary. // // mat will not attempt to detect element overlap if the input does not implement a -// Raw method, or if the Raw method differs from that of the receiver except when a -// conversion has occurred through a mat API function. Method behavior is undefined -// if there is undetected overlap. +// Raw method. Method behavior is undefined if there is undetected overlap. // package mat // import "gonum.org/v1/gonum/mat" diff --git a/mat/shadow.go b/mat/shadow.go index 5749241a..bcc55d6a 100644 --- a/mat/shadow.go +++ b/mat/shadow.go @@ -70,8 +70,8 @@ func (m *Dense) checkOverlap(a blas64.General) bool { return checkOverlap(m.RawMatrix(), a) } -func (s *SymDense) checkOverlap(a blas64.Symmetric) bool { - return checkOverlap(generalFromSymmetric(s.RawSymmetric()), generalFromSymmetric(a)) +func (s *SymDense) checkOverlap(a blas64.General) bool { + return checkOverlap(generalFromSymmetric(s.RawSymmetric()), a) } // generalFromSymmetric returns a blas64.General with the backing @@ -85,8 +85,8 @@ func generalFromSymmetric(a blas64.Symmetric) blas64.General { } } -func (t *TriDense) checkOverlap(a blas64.Triangular) bool { - return checkOverlap(generalFromTriangular(t.RawTriangular()), generalFromTriangular(a)) +func (t *TriDense) checkOverlap(a blas64.General) bool { + return checkOverlap(generalFromTriangular(t.RawTriangular()), a) } // generalFromTriangular returns a blas64.General with the backing diff --git a/mat/symmetric.go b/mat/symmetric.go index 873f5f34..9ef2b3ed 100644 --- a/mat/symmetric.go +++ b/mat/symmetric.go @@ -179,10 +179,10 @@ func (s *SymDense) AddSym(a, b Symmetric) { if b, ok := b.(RawSymmetricer); ok { amat, bmat := a.RawSymmetric(), b.RawSymmetric() if s != a { - s.checkOverlap(amat) + s.checkOverlap(generalFromSymmetric(amat)) } if s != b { - s.checkOverlap(bmat) + s.checkOverlap(generalFromSymmetric(bmat)) } for i := 0; i < n; i++ { btmp := bmat.Data[i*bmat.Stride+i : i*bmat.Stride+n] @@ -240,7 +240,7 @@ func (s *SymDense) SymRankOne(a Symmetric, alpha float64, x *VecDense) { s.reuseAs(n) if s != a { if rs, ok := a.(RawSymmetricer); ok { - s.checkOverlap(rs.RawSymmetric()) + s.checkOverlap(generalFromSymmetric(rs.RawSymmetric())) } s.CopySym(a) } @@ -266,7 +266,7 @@ func (s *SymDense) SymRankK(a Symmetric, alpha float64, x Matrix) { } if a != s { if rs, ok := a.(RawSymmetricer); ok { - s.checkOverlap(rs.RawSymmetric()) + s.checkOverlap(generalFromSymmetric(rs.RawSymmetric())) } s.reuseAs(n) s.CopySym(a) @@ -304,8 +304,13 @@ func (s *SymDense) SymOuterK(alpha float64, x Matrix) { s.CopySym(w) putWorkspaceSym(w) } else { - if rs, ok := x.(RawSymmetricer); ok { - s.checkOverlap(rs.RawSymmetric()) + switch r := x.(type) { + case RawMatrixer: + s.checkOverlap(r.RawMatrix()) + case RawSymmetricer: + s.checkOverlap(generalFromSymmetric(r.RawSymmetric())) + case RawTriangular: + s.checkOverlap(generalFromTriangular(r.RawTriangular())) } // Only zero the upper triangle. for i := 0; i < n; i++ { @@ -337,7 +342,7 @@ func (s *SymDense) RankTwo(a Symmetric, alpha float64, x, y *VecDense) { w.reuseAs(n) if s != a { if rs, ok := a.(RawSymmetricer); ok { - s.checkOverlap(rs.RawSymmetric()) + s.checkOverlap(generalFromSymmetric(rs.RawSymmetric())) } w.CopySym(a) } @@ -352,7 +357,7 @@ func (s *SymDense) ScaleSym(f float64, a Symmetric) { if a, ok := a.(RawSymmetricer); ok { amat := a.RawSymmetric() if s != a { - s.checkOverlap(amat) + s.checkOverlap(generalFromSymmetric(amat)) } for i := 0; i < n; i++ { for j := i; j < n; j++ { @@ -386,7 +391,7 @@ func (s *SymDense) SubsetSym(a Symmetric, set []int) { if a, ok := a.(RawSymmetricer); ok { raw := a.RawSymmetric() if s != a { - s.checkOverlap(raw) + s.checkOverlap(generalFromSymmetric(raw)) } for i := 0; i < n; i++ { ssub := s.mat.Data[i*s.mat.Stride : i*s.mat.Stride+n] diff --git a/mat/triangular.go b/mat/triangular.go index cb38d5e5..c49e0078 100644 --- a/mat/triangular.go +++ b/mat/triangular.go @@ -344,7 +344,7 @@ func (t *TriDense) Copy(a Matrix) (r, c int) { // avoided where possible, for example by using the Solve routines. func (t *TriDense) InverseTri(a Triangular) error { if rt, ok := a.(RawTriangular); ok { - t.checkOverlap(rt.RawTriangular()) + t.checkOverlap(generalFromTriangular(rt.RawTriangular())) } n, _ := a.Triangle() t.reuseAs(a.Triangle())