mat: replace sequential conditional assertions with type switch

This commit is contained in:
Dan Kortschak
2019-05-17 16:13:41 +09:30
parent 047f2c2add
commit a17cbc57c5

View File

@@ -300,20 +300,21 @@ func (m *Dense) Mul(a, b Matrix) {
// temporary memory.
// C = A^T * B = (B^T * A)^T
// C^T = B^T * A.
if aUrm, ok := aU.(RawMatrixer); ok {
amat := aUrm.RawMatrix()
if aU, ok := aU.(RawMatrixer); ok {
amat := aU.RawMatrix()
if restore == nil {
m.checkOverlap(amat)
}
if bUrm, ok := bU.(RawMatrixer); ok {
bmat := bUrm.RawMatrix()
switch bU := bU.(type) {
case RawMatrixer:
bmat := bU.RawMatrix()
if restore == nil {
m.checkOverlap(bmat)
}
blas64.Gemm(aT, bT, 1, amat, bmat, 0, m.mat)
return
}
if bU, ok := bU.(RawSymmetricer); ok {
case RawSymmetricer:
bmat := bU.RawSymmetric()
if aTrans {
c := getWorkspace(ac, ar, false)
@@ -324,8 +325,8 @@ func (m *Dense) Mul(a, b Matrix) {
}
blas64.Symm(blas.Right, 1, bmat, amat, 0, m.mat)
return
}
if bU, ok := bU.(RawTriangular); ok {
case RawTriangular:
// Trmm updates in place, so copy aU first.
bmat := bU.RawTriangular()
if aTrans {
@@ -345,8 +346,8 @@ func (m *Dense) Mul(a, b Matrix) {
m.Copy(a)
blas64.Trmm(blas.Right, bT, 1, bmat, m.mat)
return
}
if bU, ok := bU.(*VecDense); ok {
case *VecDense:
m.checkOverlap(bU.asGeneral())
bvec := bU.RawVector()
if bTrans {
@@ -369,12 +370,13 @@ func (m *Dense) Mul(a, b Matrix) {
return
}
}
if bUrm, ok := bU.(RawMatrixer); ok {
bmat := bUrm.RawMatrix()
if bU, ok := bU.(RawMatrixer); ok {
bmat := bU.RawMatrix()
if restore == nil {
m.checkOverlap(bmat)
}
if aU, ok := aU.(RawSymmetricer); ok {
switch aU := aU.(type) {
case RawSymmetricer:
amat := aU.RawSymmetric()
if bTrans {
c := getWorkspace(bc, br, false)
@@ -385,8 +387,8 @@ func (m *Dense) Mul(a, b Matrix) {
}
blas64.Symm(blas.Left, 1, amat, bmat, 0, m.mat)
return
}
if aU, ok := aU.(RawTriangular); ok {
case RawTriangular:
// Trmm updates in place, so copy bU first.
amat := aU.RawTriangular()
if bTrans {
@@ -406,8 +408,8 @@ func (m *Dense) Mul(a, b Matrix) {
m.Copy(b)
blas64.Trmm(blas.Left, aT, 1, amat, m.mat)
return
}
if aU, ok := aU.(*VecDense); ok {
case *VecDense:
m.checkOverlap(aU.asGeneral())
avec := aU.RawVector()
if aTrans {