diff --git a/mat/dense_arithmetic.go b/mat/dense_arithmetic.go index dd4526f6..5bbf1058 100644 --- a/mat/dense_arithmetic.go +++ b/mat/dense_arithmetic.go @@ -300,20 +300,21 @@ func (m *Dense) Mul(a, b Matrix) { // temporary memory. // C = A^T * B = (B^T * A)^T // C^T = B^T * A. - if aUrm, ok := aU.(RawMatrixer); ok { - amat := aUrm.RawMatrix() + if aU, ok := aU.(RawMatrixer); ok { + amat := aU.RawMatrix() if restore == nil { m.checkOverlap(amat) } - if bUrm, ok := bU.(RawMatrixer); ok { - bmat := bUrm.RawMatrix() + switch bU := bU.(type) { + case RawMatrixer: + bmat := bU.RawMatrix() if restore == nil { m.checkOverlap(bmat) } blas64.Gemm(aT, bT, 1, amat, bmat, 0, m.mat) return - } - if bU, ok := bU.(RawSymmetricer); ok { + + case RawSymmetricer: bmat := bU.RawSymmetric() if aTrans { c := getWorkspace(ac, ar, false) @@ -324,8 +325,8 @@ func (m *Dense) Mul(a, b Matrix) { } blas64.Symm(blas.Right, 1, bmat, amat, 0, m.mat) return - } - if bU, ok := bU.(RawTriangular); ok { + + case RawTriangular: // Trmm updates in place, so copy aU first. bmat := bU.RawTriangular() if aTrans { @@ -345,8 +346,8 @@ func (m *Dense) Mul(a, b Matrix) { m.Copy(a) blas64.Trmm(blas.Right, bT, 1, bmat, m.mat) return - } - if bU, ok := bU.(*VecDense); ok { + + case *VecDense: m.checkOverlap(bU.asGeneral()) bvec := bU.RawVector() if bTrans { @@ -369,12 +370,13 @@ func (m *Dense) Mul(a, b Matrix) { return } } - if bUrm, ok := bU.(RawMatrixer); ok { - bmat := bUrm.RawMatrix() + if bU, ok := bU.(RawMatrixer); ok { + bmat := bU.RawMatrix() if restore == nil { m.checkOverlap(bmat) } - if aU, ok := aU.(RawSymmetricer); ok { + switch aU := aU.(type) { + case RawSymmetricer: amat := aU.RawSymmetric() if bTrans { c := getWorkspace(bc, br, false) @@ -385,8 +387,8 @@ func (m *Dense) Mul(a, b Matrix) { } blas64.Symm(blas.Left, 1, amat, bmat, 0, m.mat) return - } - if aU, ok := aU.(RawTriangular); ok { + + case RawTriangular: // Trmm updates in place, so copy bU first. amat := aU.RawTriangular() if bTrans { @@ -406,8 +408,8 @@ func (m *Dense) Mul(a, b Matrix) { m.Copy(b) blas64.Trmm(blas.Left, aT, 1, amat, m.mat) return - } - if aU, ok := aU.(*VecDense); ok { + + case *VecDense: m.checkOverlap(aU.asGeneral()) avec := aU.RawVector() if aTrans {