mat: add overlap detection for SymBandDense

This commit is contained in:
Vladimir Chalupecky
2019-09-26 11:52:49 +02:00
committed by Vladimír Chalupecký
parent bd50f5876c
commit 94b2bbd8ac
2 changed files with 45 additions and 0 deletions

View File

@@ -67,6 +67,8 @@ func (m *Dense) checkOverlapMatrix(a Matrix) bool {
amat = ar.RawMatrix()
case RawSymmetricer:
amat = generalFromSymmetric(ar.RawSymmetric())
case RawSymBander:
amat = generalFromSymmetricBand(ar.RawSymBand())
case RawTriangular:
amat = generalFromTriangular(ar.RawTriangular())
case RawVectorer:
@@ -92,6 +94,8 @@ func (s *SymDense) checkOverlapMatrix(a Matrix) bool {
amat = ar.RawMatrix()
case RawSymmetricer:
amat = generalFromSymmetric(ar.RawSymmetric())
case RawSymBander:
amat = generalFromSymmetricBand(ar.RawSymBand())
case RawTriangular:
amat = generalFromTriangular(ar.RawTriangular())
case RawVectorer:
@@ -128,6 +132,8 @@ func (t *TriDense) checkOverlapMatrix(a Matrix) bool {
amat = ar.RawMatrix()
case RawSymmetricer:
amat = generalFromSymmetric(ar.RawSymmetric())
case RawSymBander:
amat = generalFromSymmetricBand(ar.RawSymBand())
case RawTriangular:
amat = generalFromTriangular(ar.RawTriangular())
case RawVectorer:
@@ -196,3 +202,41 @@ func generalFromVector(a blas64.Vector, r, c int) blas64.General {
Data: a.Data,
}
}
func (s *SymBandDense) checkOverlap(a blas64.General) bool {
return checkOverlap(generalFromSymmetricBand(s.RawSymBand()), a)
}
func (s *SymBandDense) checkOverlapMatrix(a Matrix) bool {
if s == a {
return false
}
var amat blas64.General
switch ar := a.(type) {
default:
return false
case RawMatrixer:
amat = ar.RawMatrix()
case RawSymmetricer:
amat = generalFromSymmetric(ar.RawSymmetric())
case RawSymBander:
amat = generalFromSymmetricBand(ar.RawSymBand())
case RawTriangular:
amat = generalFromTriangular(ar.RawTriangular())
case RawVectorer:
r, c := a.Dims()
amat = generalFromVector(ar.RawVector(), r, c)
}
return s.checkOverlap(amat)
}
// generalFromSymmetricBand returns a blas64.General with the backing
// data and dimensions of a.
func generalFromSymmetricBand(a blas64.SymmetricBand) blas64.General {
return blas64.General{
Rows: a.N,
Cols: a.K + 1,
Data: a.Data,
Stride: a.Stride,
}
}

View File

@@ -597,6 +597,7 @@ func (v *VecDense) MulVec(a Matrix, b Vector) {
return
case *SymBandDense:
if fast {
aU.checkOverlap(v.asGeneral())
blas64.Sbmv(1, aU.mat, bmat, 0, v.mat)
return
}