diff --git a/mat/dense_arithmetic.go b/mat/dense_arithmetic.go index 7a4afc86..4fe8b661 100644 --- a/mat/dense_arithmetic.go +++ b/mat/dense_arithmetic.go @@ -276,8 +276,8 @@ func (m *Dense) Mul(a, b Matrix) { panic(ErrShape) } - aU, aTrans := untranspose(a) - bU, bTrans := untranspose(b) + aU, aTrans := untransposeExtract(a) + bU, bTrans := untransposeExtract(b) m.reuseAs(ar, bc) var restore func() if m == aU { @@ -300,51 +300,47 @@ func (m *Dense) Mul(a, b Matrix) { // temporary memory. // C = A^T * B = (B^T * A)^T // C^T = B^T * A. - if aU, ok := aU.(RawMatrixer); ok { - amat := aU.RawMatrix() + if aU, ok := aU.(*Dense); ok { if restore == nil { - m.checkOverlap(amat) + m.checkOverlap(aU.mat) } switch bU := bU.(type) { - case RawMatrixer: - bmat := bU.RawMatrix() + case *Dense: if restore == nil { - m.checkOverlap(bmat) + m.checkOverlap(bU.mat) } - blas64.Gemm(aT, bT, 1, amat, bmat, 0, m.mat) + blas64.Gemm(aT, bT, 1, aU.mat, bU.mat, 0, m.mat) return - case RawSymmetricer: - bmat := bU.RawSymmetric() + case *SymDense: if aTrans { c := getWorkspace(ac, ar, false) - blas64.Symm(blas.Left, 1, bmat, amat, 0, c.mat) + blas64.Symm(blas.Left, 1, bU.mat, aU.mat, 0, c.mat) strictCopy(m, c.T()) putWorkspace(c) return } - blas64.Symm(blas.Right, 1, bmat, amat, 0, m.mat) + blas64.Symm(blas.Right, 1, bU.mat, aU.mat, 0, m.mat) return - case RawTriangular: + case *TriDense: // Trmm updates in place, so copy aU first. - bmat := bU.RawTriangular() if aTrans { c := getWorkspace(ac, ar, false) var tmp Dense - tmp.SetRawMatrix(amat) + tmp.SetRawMatrix(aU.mat) c.Copy(&tmp) bT := blas.Trans if bTrans { bT = blas.NoTrans } - blas64.Trmm(blas.Left, bT, 1, bmat, c.mat) + blas64.Trmm(blas.Left, bT, 1, bU.mat, c.mat) strictCopy(m, c.T()) putWorkspace(c) return } m.Copy(a) - blas64.Trmm(blas.Right, bT, 1, bmat, m.mat) + blas64.Trmm(blas.Right, bT, 1, bU.mat, m.mat) return case *VecDense: @@ -359,54 +355,51 @@ func (m *Dense) Mul(a, b Matrix) { Stride: bvec.Inc, Data: bvec.Data, } - blas64.Gemm(aT, bT, 1, amat, bmat, 0, m.mat) + blas64.Gemm(aT, bT, 1, aU.mat, bmat, 0, m.mat) return } cvec := blas64.Vector{ Inc: m.mat.Stride, Data: m.mat.Data, } - blas64.Gemv(aT, 1, amat, bvec, 0, cvec) + blas64.Gemv(aT, 1, aU.mat, bvec, 0, cvec) return } } - if bU, ok := bU.(RawMatrixer); ok { - bmat := bU.RawMatrix() + if bU, ok := bU.(*Dense); ok { if restore == nil { - m.checkOverlap(bmat) + m.checkOverlap(bU.mat) } switch aU := aU.(type) { - case RawSymmetricer: - amat := aU.RawSymmetric() + case *SymDense: if bTrans { c := getWorkspace(bc, br, false) - blas64.Symm(blas.Right, 1, amat, bmat, 0, c.mat) + blas64.Symm(blas.Right, 1, aU.mat, bU.mat, 0, c.mat) strictCopy(m, c.T()) putWorkspace(c) return } - blas64.Symm(blas.Left, 1, amat, bmat, 0, m.mat) + blas64.Symm(blas.Left, 1, aU.mat, bU.mat, 0, m.mat) return - case RawTriangular: + case *TriDense: // Trmm updates in place, so copy bU first. - amat := aU.RawTriangular() if bTrans { c := getWorkspace(bc, br, false) var tmp Dense - tmp.SetRawMatrix(bmat) + tmp.SetRawMatrix(bU.mat) c.Copy(&tmp) aT := blas.Trans if aTrans { aT = blas.NoTrans } - blas64.Trmm(blas.Right, aT, 1, amat, c.mat) + blas64.Trmm(blas.Right, aT, 1, aU.mat, c.mat) strictCopy(m, c.T()) putWorkspace(c) return } m.Copy(b) - blas64.Trmm(blas.Left, aT, 1, amat, m.mat) + blas64.Trmm(blas.Left, aT, 1, aU.mat, m.mat) return case *VecDense: @@ -423,7 +416,7 @@ func (m *Dense) Mul(a, b Matrix) { if bTrans { bT = blas.NoTrans } - blas64.Gemv(bT, 1, bmat, avec, 0, cvec) + blas64.Gemv(bT, 1, bU.mat, avec, 0, cvec) return } // {ar,1} x {1,bc} which is not a vector result. @@ -434,7 +427,7 @@ func (m *Dense) Mul(a, b Matrix) { Stride: avec.Inc, Data: avec.Data, } - blas64.Gemm(aT, bT, 1, amat, bmat, 0, m.mat) + blas64.Gemm(aT, bT, 1, amat, bU.mat, 0, m.mat) return } }