mirror of
https://github.com/gonum/gonum.git
synced 2025-10-06 23:52:47 +08:00
mat: add overlap detection for SymBandDense
This commit is contained in:

committed by
Vladimír Chalupecký

parent
bd50f5876c
commit
94b2bbd8ac
@@ -67,6 +67,8 @@ func (m *Dense) checkOverlapMatrix(a Matrix) bool {
|
|||||||
amat = ar.RawMatrix()
|
amat = ar.RawMatrix()
|
||||||
case RawSymmetricer:
|
case RawSymmetricer:
|
||||||
amat = generalFromSymmetric(ar.RawSymmetric())
|
amat = generalFromSymmetric(ar.RawSymmetric())
|
||||||
|
case RawSymBander:
|
||||||
|
amat = generalFromSymmetricBand(ar.RawSymBand())
|
||||||
case RawTriangular:
|
case RawTriangular:
|
||||||
amat = generalFromTriangular(ar.RawTriangular())
|
amat = generalFromTriangular(ar.RawTriangular())
|
||||||
case RawVectorer:
|
case RawVectorer:
|
||||||
@@ -92,6 +94,8 @@ func (s *SymDense) checkOverlapMatrix(a Matrix) bool {
|
|||||||
amat = ar.RawMatrix()
|
amat = ar.RawMatrix()
|
||||||
case RawSymmetricer:
|
case RawSymmetricer:
|
||||||
amat = generalFromSymmetric(ar.RawSymmetric())
|
amat = generalFromSymmetric(ar.RawSymmetric())
|
||||||
|
case RawSymBander:
|
||||||
|
amat = generalFromSymmetricBand(ar.RawSymBand())
|
||||||
case RawTriangular:
|
case RawTriangular:
|
||||||
amat = generalFromTriangular(ar.RawTriangular())
|
amat = generalFromTriangular(ar.RawTriangular())
|
||||||
case RawVectorer:
|
case RawVectorer:
|
||||||
@@ -128,6 +132,8 @@ func (t *TriDense) checkOverlapMatrix(a Matrix) bool {
|
|||||||
amat = ar.RawMatrix()
|
amat = ar.RawMatrix()
|
||||||
case RawSymmetricer:
|
case RawSymmetricer:
|
||||||
amat = generalFromSymmetric(ar.RawSymmetric())
|
amat = generalFromSymmetric(ar.RawSymmetric())
|
||||||
|
case RawSymBander:
|
||||||
|
amat = generalFromSymmetricBand(ar.RawSymBand())
|
||||||
case RawTriangular:
|
case RawTriangular:
|
||||||
amat = generalFromTriangular(ar.RawTriangular())
|
amat = generalFromTriangular(ar.RawTriangular())
|
||||||
case RawVectorer:
|
case RawVectorer:
|
||||||
@@ -196,3 +202,41 @@ func generalFromVector(a blas64.Vector, r, c int) blas64.General {
|
|||||||
Data: a.Data,
|
Data: a.Data,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *SymBandDense) checkOverlap(a blas64.General) bool {
|
||||||
|
return checkOverlap(generalFromSymmetricBand(s.RawSymBand()), a)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SymBandDense) checkOverlapMatrix(a Matrix) bool {
|
||||||
|
if s == a {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
var amat blas64.General
|
||||||
|
switch ar := a.(type) {
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
case RawMatrixer:
|
||||||
|
amat = ar.RawMatrix()
|
||||||
|
case RawSymmetricer:
|
||||||
|
amat = generalFromSymmetric(ar.RawSymmetric())
|
||||||
|
case RawSymBander:
|
||||||
|
amat = generalFromSymmetricBand(ar.RawSymBand())
|
||||||
|
case RawTriangular:
|
||||||
|
amat = generalFromTriangular(ar.RawTriangular())
|
||||||
|
case RawVectorer:
|
||||||
|
r, c := a.Dims()
|
||||||
|
amat = generalFromVector(ar.RawVector(), r, c)
|
||||||
|
}
|
||||||
|
return s.checkOverlap(amat)
|
||||||
|
}
|
||||||
|
|
||||||
|
// generalFromSymmetricBand returns a blas64.General with the backing
|
||||||
|
// data and dimensions of a.
|
||||||
|
func generalFromSymmetricBand(a blas64.SymmetricBand) blas64.General {
|
||||||
|
return blas64.General{
|
||||||
|
Rows: a.N,
|
||||||
|
Cols: a.K + 1,
|
||||||
|
Data: a.Data,
|
||||||
|
Stride: a.Stride,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@@ -597,6 +597,7 @@ func (v *VecDense) MulVec(a Matrix, b Vector) {
|
|||||||
return
|
return
|
||||||
case *SymBandDense:
|
case *SymBandDense:
|
||||||
if fast {
|
if fast {
|
||||||
|
aU.checkOverlap(v.asGeneral())
|
||||||
blas64.Sbmv(1, aU.mat, bmat, 0, v.mat)
|
blas64.Sbmv(1, aU.mat, bmat, 0, v.mat)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user