mirror of
https://github.com/gonum/gonum.git
synced 2025-10-06 23:52:47 +08:00
mat: allow cross-raw shadow detection
This commit is contained in:
@@ -164,8 +164,6 @@
|
||||
// the value has been untransposed if necessary.
|
||||
//
|
||||
// mat will not attempt to detect element overlap if the input does not implement a
|
||||
// Raw method, or if the Raw method differs from that of the receiver except when a
|
||||
// conversion has occurred through a mat API function. Method behavior is undefined
|
||||
// if there is undetected overlap.
|
||||
// Raw method. Method behavior is undefined if there is undetected overlap.
|
||||
//
|
||||
package mat // import "gonum.org/v1/gonum/mat"
|
||||
|
@@ -70,8 +70,8 @@ func (m *Dense) checkOverlap(a blas64.General) bool {
|
||||
return checkOverlap(m.RawMatrix(), a)
|
||||
}
|
||||
|
||||
func (s *SymDense) checkOverlap(a blas64.Symmetric) bool {
|
||||
return checkOverlap(generalFromSymmetric(s.RawSymmetric()), generalFromSymmetric(a))
|
||||
func (s *SymDense) checkOverlap(a blas64.General) bool {
|
||||
return checkOverlap(generalFromSymmetric(s.RawSymmetric()), a)
|
||||
}
|
||||
|
||||
// generalFromSymmetric returns a blas64.General with the backing
|
||||
@@ -85,8 +85,8 @@ func generalFromSymmetric(a blas64.Symmetric) blas64.General {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *TriDense) checkOverlap(a blas64.Triangular) bool {
|
||||
return checkOverlap(generalFromTriangular(t.RawTriangular()), generalFromTriangular(a))
|
||||
func (t *TriDense) checkOverlap(a blas64.General) bool {
|
||||
return checkOverlap(generalFromTriangular(t.RawTriangular()), a)
|
||||
}
|
||||
|
||||
// generalFromTriangular returns a blas64.General with the backing
|
||||
|
@@ -179,10 +179,10 @@ func (s *SymDense) AddSym(a, b Symmetric) {
|
||||
if b, ok := b.(RawSymmetricer); ok {
|
||||
amat, bmat := a.RawSymmetric(), b.RawSymmetric()
|
||||
if s != a {
|
||||
s.checkOverlap(amat)
|
||||
s.checkOverlap(generalFromSymmetric(amat))
|
||||
}
|
||||
if s != b {
|
||||
s.checkOverlap(bmat)
|
||||
s.checkOverlap(generalFromSymmetric(bmat))
|
||||
}
|
||||
for i := 0; i < n; i++ {
|
||||
btmp := bmat.Data[i*bmat.Stride+i : i*bmat.Stride+n]
|
||||
@@ -240,7 +240,7 @@ func (s *SymDense) SymRankOne(a Symmetric, alpha float64, x *VecDense) {
|
||||
s.reuseAs(n)
|
||||
if s != a {
|
||||
if rs, ok := a.(RawSymmetricer); ok {
|
||||
s.checkOverlap(rs.RawSymmetric())
|
||||
s.checkOverlap(generalFromSymmetric(rs.RawSymmetric()))
|
||||
}
|
||||
s.CopySym(a)
|
||||
}
|
||||
@@ -266,7 +266,7 @@ func (s *SymDense) SymRankK(a Symmetric, alpha float64, x Matrix) {
|
||||
}
|
||||
if a != s {
|
||||
if rs, ok := a.(RawSymmetricer); ok {
|
||||
s.checkOverlap(rs.RawSymmetric())
|
||||
s.checkOverlap(generalFromSymmetric(rs.RawSymmetric()))
|
||||
}
|
||||
s.reuseAs(n)
|
||||
s.CopySym(a)
|
||||
@@ -304,8 +304,13 @@ func (s *SymDense) SymOuterK(alpha float64, x Matrix) {
|
||||
s.CopySym(w)
|
||||
putWorkspaceSym(w)
|
||||
} else {
|
||||
if rs, ok := x.(RawSymmetricer); ok {
|
||||
s.checkOverlap(rs.RawSymmetric())
|
||||
switch r := x.(type) {
|
||||
case RawMatrixer:
|
||||
s.checkOverlap(r.RawMatrix())
|
||||
case RawSymmetricer:
|
||||
s.checkOverlap(generalFromSymmetric(r.RawSymmetric()))
|
||||
case RawTriangular:
|
||||
s.checkOverlap(generalFromTriangular(r.RawTriangular()))
|
||||
}
|
||||
// Only zero the upper triangle.
|
||||
for i := 0; i < n; i++ {
|
||||
@@ -337,7 +342,7 @@ func (s *SymDense) RankTwo(a Symmetric, alpha float64, x, y *VecDense) {
|
||||
w.reuseAs(n)
|
||||
if s != a {
|
||||
if rs, ok := a.(RawSymmetricer); ok {
|
||||
s.checkOverlap(rs.RawSymmetric())
|
||||
s.checkOverlap(generalFromSymmetric(rs.RawSymmetric()))
|
||||
}
|
||||
w.CopySym(a)
|
||||
}
|
||||
@@ -352,7 +357,7 @@ func (s *SymDense) ScaleSym(f float64, a Symmetric) {
|
||||
if a, ok := a.(RawSymmetricer); ok {
|
||||
amat := a.RawSymmetric()
|
||||
if s != a {
|
||||
s.checkOverlap(amat)
|
||||
s.checkOverlap(generalFromSymmetric(amat))
|
||||
}
|
||||
for i := 0; i < n; i++ {
|
||||
for j := i; j < n; j++ {
|
||||
@@ -386,7 +391,7 @@ func (s *SymDense) SubsetSym(a Symmetric, set []int) {
|
||||
if a, ok := a.(RawSymmetricer); ok {
|
||||
raw := a.RawSymmetric()
|
||||
if s != a {
|
||||
s.checkOverlap(raw)
|
||||
s.checkOverlap(generalFromSymmetric(raw))
|
||||
}
|
||||
for i := 0; i < n; i++ {
|
||||
ssub := s.mat.Data[i*s.mat.Stride : i*s.mat.Stride+n]
|
||||
|
@@ -344,7 +344,7 @@ func (t *TriDense) Copy(a Matrix) (r, c int) {
|
||||
// avoided where possible, for example by using the Solve routines.
|
||||
func (t *TriDense) InverseTri(a Triangular) error {
|
||||
if rt, ok := a.(RawTriangular); ok {
|
||||
t.checkOverlap(rt.RawTriangular())
|
||||
t.checkOverlap(generalFromTriangular(rt.RawTriangular()))
|
||||
}
|
||||
n, _ := a.Triangle()
|
||||
t.reuseAs(a.Triangle())
|
||||
|
Reference in New Issue
Block a user