mirror of
https://github.com/gonum/gonum.git
synced 2025-10-06 07:37:03 +08:00
mat: replace sequential conditional assertions with type switch
This commit is contained in:
@@ -300,20 +300,21 @@ func (m *Dense) Mul(a, b Matrix) {
|
|||||||
// temporary memory.
|
// temporary memory.
|
||||||
// C = A^T * B = (B^T * A)^T
|
// C = A^T * B = (B^T * A)^T
|
||||||
// C^T = B^T * A.
|
// C^T = B^T * A.
|
||||||
if aUrm, ok := aU.(RawMatrixer); ok {
|
if aU, ok := aU.(RawMatrixer); ok {
|
||||||
amat := aUrm.RawMatrix()
|
amat := aU.RawMatrix()
|
||||||
if restore == nil {
|
if restore == nil {
|
||||||
m.checkOverlap(amat)
|
m.checkOverlap(amat)
|
||||||
}
|
}
|
||||||
if bUrm, ok := bU.(RawMatrixer); ok {
|
switch bU := bU.(type) {
|
||||||
bmat := bUrm.RawMatrix()
|
case RawMatrixer:
|
||||||
|
bmat := bU.RawMatrix()
|
||||||
if restore == nil {
|
if restore == nil {
|
||||||
m.checkOverlap(bmat)
|
m.checkOverlap(bmat)
|
||||||
}
|
}
|
||||||
blas64.Gemm(aT, bT, 1, amat, bmat, 0, m.mat)
|
blas64.Gemm(aT, bT, 1, amat, bmat, 0, m.mat)
|
||||||
return
|
return
|
||||||
}
|
|
||||||
if bU, ok := bU.(RawSymmetricer); ok {
|
case RawSymmetricer:
|
||||||
bmat := bU.RawSymmetric()
|
bmat := bU.RawSymmetric()
|
||||||
if aTrans {
|
if aTrans {
|
||||||
c := getWorkspace(ac, ar, false)
|
c := getWorkspace(ac, ar, false)
|
||||||
@@ -324,8 +325,8 @@ func (m *Dense) Mul(a, b Matrix) {
|
|||||||
}
|
}
|
||||||
blas64.Symm(blas.Right, 1, bmat, amat, 0, m.mat)
|
blas64.Symm(blas.Right, 1, bmat, amat, 0, m.mat)
|
||||||
return
|
return
|
||||||
}
|
|
||||||
if bU, ok := bU.(RawTriangular); ok {
|
case RawTriangular:
|
||||||
// Trmm updates in place, so copy aU first.
|
// Trmm updates in place, so copy aU first.
|
||||||
bmat := bU.RawTriangular()
|
bmat := bU.RawTriangular()
|
||||||
if aTrans {
|
if aTrans {
|
||||||
@@ -345,8 +346,8 @@ func (m *Dense) Mul(a, b Matrix) {
|
|||||||
m.Copy(a)
|
m.Copy(a)
|
||||||
blas64.Trmm(blas.Right, bT, 1, bmat, m.mat)
|
blas64.Trmm(blas.Right, bT, 1, bmat, m.mat)
|
||||||
return
|
return
|
||||||
}
|
|
||||||
if bU, ok := bU.(*VecDense); ok {
|
case *VecDense:
|
||||||
m.checkOverlap(bU.asGeneral())
|
m.checkOverlap(bU.asGeneral())
|
||||||
bvec := bU.RawVector()
|
bvec := bU.RawVector()
|
||||||
if bTrans {
|
if bTrans {
|
||||||
@@ -369,12 +370,13 @@ func (m *Dense) Mul(a, b Matrix) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if bUrm, ok := bU.(RawMatrixer); ok {
|
if bU, ok := bU.(RawMatrixer); ok {
|
||||||
bmat := bUrm.RawMatrix()
|
bmat := bU.RawMatrix()
|
||||||
if restore == nil {
|
if restore == nil {
|
||||||
m.checkOverlap(bmat)
|
m.checkOverlap(bmat)
|
||||||
}
|
}
|
||||||
if aU, ok := aU.(RawSymmetricer); ok {
|
switch aU := aU.(type) {
|
||||||
|
case RawSymmetricer:
|
||||||
amat := aU.RawSymmetric()
|
amat := aU.RawSymmetric()
|
||||||
if bTrans {
|
if bTrans {
|
||||||
c := getWorkspace(bc, br, false)
|
c := getWorkspace(bc, br, false)
|
||||||
@@ -385,8 +387,8 @@ func (m *Dense) Mul(a, b Matrix) {
|
|||||||
}
|
}
|
||||||
blas64.Symm(blas.Left, 1, amat, bmat, 0, m.mat)
|
blas64.Symm(blas.Left, 1, amat, bmat, 0, m.mat)
|
||||||
return
|
return
|
||||||
}
|
|
||||||
if aU, ok := aU.(RawTriangular); ok {
|
case RawTriangular:
|
||||||
// Trmm updates in place, so copy bU first.
|
// Trmm updates in place, so copy bU first.
|
||||||
amat := aU.RawTriangular()
|
amat := aU.RawTriangular()
|
||||||
if bTrans {
|
if bTrans {
|
||||||
@@ -406,8 +408,8 @@ func (m *Dense) Mul(a, b Matrix) {
|
|||||||
m.Copy(b)
|
m.Copy(b)
|
||||||
blas64.Trmm(blas.Left, aT, 1, amat, m.mat)
|
blas64.Trmm(blas.Left, aT, 1, amat, m.mat)
|
||||||
return
|
return
|
||||||
}
|
|
||||||
if aU, ok := aU.(*VecDense); ok {
|
case *VecDense:
|
||||||
m.checkOverlap(aU.asGeneral())
|
m.checkOverlap(aU.asGeneral())
|
||||||
avec := aU.RawVector()
|
avec := aU.RawVector()
|
||||||
if aTrans {
|
if aTrans {
|
||||||
|
Reference in New Issue
Block a user