diff --git a/mat/shadow.go b/mat/shadow.go index c3f683e5..8a941c7f 100644 --- a/mat/shadow.go +++ b/mat/shadow.go @@ -67,6 +67,8 @@ func (m *Dense) checkOverlapMatrix(a Matrix) bool { amat = ar.RawMatrix() case RawSymmetricer: amat = generalFromSymmetric(ar.RawSymmetric()) + case RawSymBander: + amat = generalFromSymmetricBand(ar.RawSymBand()) case RawTriangular: amat = generalFromTriangular(ar.RawTriangular()) case RawVectorer: @@ -92,6 +94,8 @@ func (s *SymDense) checkOverlapMatrix(a Matrix) bool { amat = ar.RawMatrix() case RawSymmetricer: amat = generalFromSymmetric(ar.RawSymmetric()) + case RawSymBander: + amat = generalFromSymmetricBand(ar.RawSymBand()) case RawTriangular: amat = generalFromTriangular(ar.RawTriangular()) case RawVectorer: @@ -128,6 +132,8 @@ func (t *TriDense) checkOverlapMatrix(a Matrix) bool { amat = ar.RawMatrix() case RawSymmetricer: amat = generalFromSymmetric(ar.RawSymmetric()) + case RawSymBander: + amat = generalFromSymmetricBand(ar.RawSymBand()) case RawTriangular: amat = generalFromTriangular(ar.RawTriangular()) case RawVectorer: @@ -196,3 +202,41 @@ func generalFromVector(a blas64.Vector, r, c int) blas64.General { Data: a.Data, } } + +func (s *SymBandDense) checkOverlap(a blas64.General) bool { + return checkOverlap(generalFromSymmetricBand(s.RawSymBand()), a) +} + +func (s *SymBandDense) checkOverlapMatrix(a Matrix) bool { + if s == a { + return false + } + var amat blas64.General + switch ar := a.(type) { + default: + return false + case RawMatrixer: + amat = ar.RawMatrix() + case RawSymmetricer: + amat = generalFromSymmetric(ar.RawSymmetric()) + case RawSymBander: + amat = generalFromSymmetricBand(ar.RawSymBand()) + case RawTriangular: + amat = generalFromTriangular(ar.RawTriangular()) + case RawVectorer: + r, c := a.Dims() + amat = generalFromVector(ar.RawVector(), r, c) + } + return s.checkOverlap(amat) +} + +// generalFromSymmetricBand returns a blas64.General with the backing +// data and dimensions of a. +func generalFromSymmetricBand(a blas64.SymmetricBand) blas64.General { + return blas64.General{ + Rows: a.N, + Cols: a.K + 1, + Data: a.Data, + Stride: a.Stride, + } +} diff --git a/mat/vector.go b/mat/vector.go index 5e31f978..0f0f3cbf 100644 --- a/mat/vector.go +++ b/mat/vector.go @@ -597,6 +597,7 @@ func (v *VecDense) MulVec(a Matrix, b Vector) { return case *SymBandDense: if fast { + aU.checkOverlap(v.asGeneral()) blas64.Sbmv(1, aU.mat, bmat, 0, v.mat) return }