mat: Use untransposeExtract in Dense.Mul (#997)

This commit is contained in:
Brendan Tracey
2019-07-23 20:27:42 +01:00
committed by GitHub
parent 9cfd3e46f1
commit 2122b538b6

View File

@@ -276,8 +276,8 @@ func (m *Dense) Mul(a, b Matrix) {
panic(ErrShape)
}
aU, aTrans := untranspose(a)
bU, bTrans := untranspose(b)
aU, aTrans := untransposeExtract(a)
bU, bTrans := untransposeExtract(b)
m.reuseAs(ar, bc)
var restore func()
if m == aU {
@@ -300,51 +300,47 @@ func (m *Dense) Mul(a, b Matrix) {
// temporary memory.
// C = A^T * B = (B^T * A)^T
// C^T = B^T * A.
if aU, ok := aU.(RawMatrixer); ok {
amat := aU.RawMatrix()
if aU, ok := aU.(*Dense); ok {
if restore == nil {
m.checkOverlap(amat)
m.checkOverlap(aU.mat)
}
switch bU := bU.(type) {
case RawMatrixer:
bmat := bU.RawMatrix()
case *Dense:
if restore == nil {
m.checkOverlap(bmat)
m.checkOverlap(bU.mat)
}
blas64.Gemm(aT, bT, 1, amat, bmat, 0, m.mat)
blas64.Gemm(aT, bT, 1, aU.mat, bU.mat, 0, m.mat)
return
case RawSymmetricer:
bmat := bU.RawSymmetric()
case *SymDense:
if aTrans {
c := getWorkspace(ac, ar, false)
blas64.Symm(blas.Left, 1, bmat, amat, 0, c.mat)
blas64.Symm(blas.Left, 1, bU.mat, aU.mat, 0, c.mat)
strictCopy(m, c.T())
putWorkspace(c)
return
}
blas64.Symm(blas.Right, 1, bmat, amat, 0, m.mat)
blas64.Symm(blas.Right, 1, bU.mat, aU.mat, 0, m.mat)
return
case RawTriangular:
case *TriDense:
// Trmm updates in place, so copy aU first.
bmat := bU.RawTriangular()
if aTrans {
c := getWorkspace(ac, ar, false)
var tmp Dense
tmp.SetRawMatrix(amat)
tmp.SetRawMatrix(aU.mat)
c.Copy(&tmp)
bT := blas.Trans
if bTrans {
bT = blas.NoTrans
}
blas64.Trmm(blas.Left, bT, 1, bmat, c.mat)
blas64.Trmm(blas.Left, bT, 1, bU.mat, c.mat)
strictCopy(m, c.T())
putWorkspace(c)
return
}
m.Copy(a)
blas64.Trmm(blas.Right, bT, 1, bmat, m.mat)
blas64.Trmm(blas.Right, bT, 1, bU.mat, m.mat)
return
case *VecDense:
@@ -359,54 +355,51 @@ func (m *Dense) Mul(a, b Matrix) {
Stride: bvec.Inc,
Data: bvec.Data,
}
blas64.Gemm(aT, bT, 1, amat, bmat, 0, m.mat)
blas64.Gemm(aT, bT, 1, aU.mat, bmat, 0, m.mat)
return
}
cvec := blas64.Vector{
Inc: m.mat.Stride,
Data: m.mat.Data,
}
blas64.Gemv(aT, 1, amat, bvec, 0, cvec)
blas64.Gemv(aT, 1, aU.mat, bvec, 0, cvec)
return
}
}
if bU, ok := bU.(RawMatrixer); ok {
bmat := bU.RawMatrix()
if bU, ok := bU.(*Dense); ok {
if restore == nil {
m.checkOverlap(bmat)
m.checkOverlap(bU.mat)
}
switch aU := aU.(type) {
case RawSymmetricer:
amat := aU.RawSymmetric()
case *SymDense:
if bTrans {
c := getWorkspace(bc, br, false)
blas64.Symm(blas.Right, 1, amat, bmat, 0, c.mat)
blas64.Symm(blas.Right, 1, aU.mat, bU.mat, 0, c.mat)
strictCopy(m, c.T())
putWorkspace(c)
return
}
blas64.Symm(blas.Left, 1, amat, bmat, 0, m.mat)
blas64.Symm(blas.Left, 1, aU.mat, bU.mat, 0, m.mat)
return
case RawTriangular:
case *TriDense:
// Trmm updates in place, so copy bU first.
amat := aU.RawTriangular()
if bTrans {
c := getWorkspace(bc, br, false)
var tmp Dense
tmp.SetRawMatrix(bmat)
tmp.SetRawMatrix(bU.mat)
c.Copy(&tmp)
aT := blas.Trans
if aTrans {
aT = blas.NoTrans
}
blas64.Trmm(blas.Right, aT, 1, amat, c.mat)
blas64.Trmm(blas.Right, aT, 1, aU.mat, c.mat)
strictCopy(m, c.T())
putWorkspace(c)
return
}
m.Copy(b)
blas64.Trmm(blas.Left, aT, 1, amat, m.mat)
blas64.Trmm(blas.Left, aT, 1, aU.mat, m.mat)
return
case *VecDense:
@@ -423,7 +416,7 @@ func (m *Dense) Mul(a, b Matrix) {
if bTrans {
bT = blas.NoTrans
}
blas64.Gemv(bT, 1, bmat, avec, 0, cvec)
blas64.Gemv(bT, 1, bU.mat, avec, 0, cvec)
return
}
// {ar,1} x {1,bc} which is not a vector result.
@@ -434,7 +427,7 @@ func (m *Dense) Mul(a, b Matrix) {
Stride: avec.Inc,
Data: avec.Data,
}
blas64.Gemm(aT, bT, 1, amat, bmat, 0, m.mat)
blas64.Gemm(aT, bT, 1, amat, bU.mat, 0, m.mat)
return
}
}