mirror of
https://github.com/gonum/gonum.git
synced 2025-10-06 07:37:03 +08:00
mat: add overlap detection for SymBandDense
This commit is contained in:

committed by
Vladimír Chalupecký

parent
bd50f5876c
commit
94b2bbd8ac
@@ -67,6 +67,8 @@ func (m *Dense) checkOverlapMatrix(a Matrix) bool {
|
||||
amat = ar.RawMatrix()
|
||||
case RawSymmetricer:
|
||||
amat = generalFromSymmetric(ar.RawSymmetric())
|
||||
case RawSymBander:
|
||||
amat = generalFromSymmetricBand(ar.RawSymBand())
|
||||
case RawTriangular:
|
||||
amat = generalFromTriangular(ar.RawTriangular())
|
||||
case RawVectorer:
|
||||
@@ -92,6 +94,8 @@ func (s *SymDense) checkOverlapMatrix(a Matrix) bool {
|
||||
amat = ar.RawMatrix()
|
||||
case RawSymmetricer:
|
||||
amat = generalFromSymmetric(ar.RawSymmetric())
|
||||
case RawSymBander:
|
||||
amat = generalFromSymmetricBand(ar.RawSymBand())
|
||||
case RawTriangular:
|
||||
amat = generalFromTriangular(ar.RawTriangular())
|
||||
case RawVectorer:
|
||||
@@ -128,6 +132,8 @@ func (t *TriDense) checkOverlapMatrix(a Matrix) bool {
|
||||
amat = ar.RawMatrix()
|
||||
case RawSymmetricer:
|
||||
amat = generalFromSymmetric(ar.RawSymmetric())
|
||||
case RawSymBander:
|
||||
amat = generalFromSymmetricBand(ar.RawSymBand())
|
||||
case RawTriangular:
|
||||
amat = generalFromTriangular(ar.RawTriangular())
|
||||
case RawVectorer:
|
||||
@@ -196,3 +202,41 @@ func generalFromVector(a blas64.Vector, r, c int) blas64.General {
|
||||
Data: a.Data,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SymBandDense) checkOverlap(a blas64.General) bool {
|
||||
return checkOverlap(generalFromSymmetricBand(s.RawSymBand()), a)
|
||||
}
|
||||
|
||||
func (s *SymBandDense) checkOverlapMatrix(a Matrix) bool {
|
||||
if s == a {
|
||||
return false
|
||||
}
|
||||
var amat blas64.General
|
||||
switch ar := a.(type) {
|
||||
default:
|
||||
return false
|
||||
case RawMatrixer:
|
||||
amat = ar.RawMatrix()
|
||||
case RawSymmetricer:
|
||||
amat = generalFromSymmetric(ar.RawSymmetric())
|
||||
case RawSymBander:
|
||||
amat = generalFromSymmetricBand(ar.RawSymBand())
|
||||
case RawTriangular:
|
||||
amat = generalFromTriangular(ar.RawTriangular())
|
||||
case RawVectorer:
|
||||
r, c := a.Dims()
|
||||
amat = generalFromVector(ar.RawVector(), r, c)
|
||||
}
|
||||
return s.checkOverlap(amat)
|
||||
}
|
||||
|
||||
// generalFromSymmetricBand returns a blas64.General with the backing
|
||||
// data and dimensions of a.
|
||||
func generalFromSymmetricBand(a blas64.SymmetricBand) blas64.General {
|
||||
return blas64.General{
|
||||
Rows: a.N,
|
||||
Cols: a.K + 1,
|
||||
Data: a.Data,
|
||||
Stride: a.Stride,
|
||||
}
|
||||
}
|
||||
|
@@ -597,6 +597,7 @@ func (v *VecDense) MulVec(a Matrix, b Vector) {
|
||||
return
|
||||
case *SymBandDense:
|
||||
if fast {
|
||||
aU.checkOverlap(v.asGeneral())
|
||||
blas64.Sbmv(1, aU.mat, bmat, 0, v.mat)
|
||||
return
|
||||
}
|
||||
|
Reference in New Issue
Block a user