mirror of
https://github.com/gonum/gonum.git
synced 2025-10-05 23:26:52 +08:00
mat: Use untransposeExtract in Dense.Mul (#997)
This commit is contained in:
@@ -276,8 +276,8 @@ func (m *Dense) Mul(a, b Matrix) {
|
||||
panic(ErrShape)
|
||||
}
|
||||
|
||||
aU, aTrans := untranspose(a)
|
||||
bU, bTrans := untranspose(b)
|
||||
aU, aTrans := untransposeExtract(a)
|
||||
bU, bTrans := untransposeExtract(b)
|
||||
m.reuseAs(ar, bc)
|
||||
var restore func()
|
||||
if m == aU {
|
||||
@@ -300,51 +300,47 @@ func (m *Dense) Mul(a, b Matrix) {
|
||||
// temporary memory.
|
||||
// C = A^T * B = (B^T * A)^T
|
||||
// C^T = B^T * A.
|
||||
if aU, ok := aU.(RawMatrixer); ok {
|
||||
amat := aU.RawMatrix()
|
||||
if aU, ok := aU.(*Dense); ok {
|
||||
if restore == nil {
|
||||
m.checkOverlap(amat)
|
||||
m.checkOverlap(aU.mat)
|
||||
}
|
||||
switch bU := bU.(type) {
|
||||
case RawMatrixer:
|
||||
bmat := bU.RawMatrix()
|
||||
case *Dense:
|
||||
if restore == nil {
|
||||
m.checkOverlap(bmat)
|
||||
m.checkOverlap(bU.mat)
|
||||
}
|
||||
blas64.Gemm(aT, bT, 1, amat, bmat, 0, m.mat)
|
||||
blas64.Gemm(aT, bT, 1, aU.mat, bU.mat, 0, m.mat)
|
||||
return
|
||||
|
||||
case RawSymmetricer:
|
||||
bmat := bU.RawSymmetric()
|
||||
case *SymDense:
|
||||
if aTrans {
|
||||
c := getWorkspace(ac, ar, false)
|
||||
blas64.Symm(blas.Left, 1, bmat, amat, 0, c.mat)
|
||||
blas64.Symm(blas.Left, 1, bU.mat, aU.mat, 0, c.mat)
|
||||
strictCopy(m, c.T())
|
||||
putWorkspace(c)
|
||||
return
|
||||
}
|
||||
blas64.Symm(blas.Right, 1, bmat, amat, 0, m.mat)
|
||||
blas64.Symm(blas.Right, 1, bU.mat, aU.mat, 0, m.mat)
|
||||
return
|
||||
|
||||
case RawTriangular:
|
||||
case *TriDense:
|
||||
// Trmm updates in place, so copy aU first.
|
||||
bmat := bU.RawTriangular()
|
||||
if aTrans {
|
||||
c := getWorkspace(ac, ar, false)
|
||||
var tmp Dense
|
||||
tmp.SetRawMatrix(amat)
|
||||
tmp.SetRawMatrix(aU.mat)
|
||||
c.Copy(&tmp)
|
||||
bT := blas.Trans
|
||||
if bTrans {
|
||||
bT = blas.NoTrans
|
||||
}
|
||||
blas64.Trmm(blas.Left, bT, 1, bmat, c.mat)
|
||||
blas64.Trmm(blas.Left, bT, 1, bU.mat, c.mat)
|
||||
strictCopy(m, c.T())
|
||||
putWorkspace(c)
|
||||
return
|
||||
}
|
||||
m.Copy(a)
|
||||
blas64.Trmm(blas.Right, bT, 1, bmat, m.mat)
|
||||
blas64.Trmm(blas.Right, bT, 1, bU.mat, m.mat)
|
||||
return
|
||||
|
||||
case *VecDense:
|
||||
@@ -359,54 +355,51 @@ func (m *Dense) Mul(a, b Matrix) {
|
||||
Stride: bvec.Inc,
|
||||
Data: bvec.Data,
|
||||
}
|
||||
blas64.Gemm(aT, bT, 1, amat, bmat, 0, m.mat)
|
||||
blas64.Gemm(aT, bT, 1, aU.mat, bmat, 0, m.mat)
|
||||
return
|
||||
}
|
||||
cvec := blas64.Vector{
|
||||
Inc: m.mat.Stride,
|
||||
Data: m.mat.Data,
|
||||
}
|
||||
blas64.Gemv(aT, 1, amat, bvec, 0, cvec)
|
||||
blas64.Gemv(aT, 1, aU.mat, bvec, 0, cvec)
|
||||
return
|
||||
}
|
||||
}
|
||||
if bU, ok := bU.(RawMatrixer); ok {
|
||||
bmat := bU.RawMatrix()
|
||||
if bU, ok := bU.(*Dense); ok {
|
||||
if restore == nil {
|
||||
m.checkOverlap(bmat)
|
||||
m.checkOverlap(bU.mat)
|
||||
}
|
||||
switch aU := aU.(type) {
|
||||
case RawSymmetricer:
|
||||
amat := aU.RawSymmetric()
|
||||
case *SymDense:
|
||||
if bTrans {
|
||||
c := getWorkspace(bc, br, false)
|
||||
blas64.Symm(blas.Right, 1, amat, bmat, 0, c.mat)
|
||||
blas64.Symm(blas.Right, 1, aU.mat, bU.mat, 0, c.mat)
|
||||
strictCopy(m, c.T())
|
||||
putWorkspace(c)
|
||||
return
|
||||
}
|
||||
blas64.Symm(blas.Left, 1, amat, bmat, 0, m.mat)
|
||||
blas64.Symm(blas.Left, 1, aU.mat, bU.mat, 0, m.mat)
|
||||
return
|
||||
|
||||
case RawTriangular:
|
||||
case *TriDense:
|
||||
// Trmm updates in place, so copy bU first.
|
||||
amat := aU.RawTriangular()
|
||||
if bTrans {
|
||||
c := getWorkspace(bc, br, false)
|
||||
var tmp Dense
|
||||
tmp.SetRawMatrix(bmat)
|
||||
tmp.SetRawMatrix(bU.mat)
|
||||
c.Copy(&tmp)
|
||||
aT := blas.Trans
|
||||
if aTrans {
|
||||
aT = blas.NoTrans
|
||||
}
|
||||
blas64.Trmm(blas.Right, aT, 1, amat, c.mat)
|
||||
blas64.Trmm(blas.Right, aT, 1, aU.mat, c.mat)
|
||||
strictCopy(m, c.T())
|
||||
putWorkspace(c)
|
||||
return
|
||||
}
|
||||
m.Copy(b)
|
||||
blas64.Trmm(blas.Left, aT, 1, amat, m.mat)
|
||||
blas64.Trmm(blas.Left, aT, 1, aU.mat, m.mat)
|
||||
return
|
||||
|
||||
case *VecDense:
|
||||
@@ -423,7 +416,7 @@ func (m *Dense) Mul(a, b Matrix) {
|
||||
if bTrans {
|
||||
bT = blas.NoTrans
|
||||
}
|
||||
blas64.Gemv(bT, 1, bmat, avec, 0, cvec)
|
||||
blas64.Gemv(bT, 1, bU.mat, avec, 0, cvec)
|
||||
return
|
||||
}
|
||||
// {ar,1} x {1,bc} which is not a vector result.
|
||||
@@ -434,7 +427,7 @@ func (m *Dense) Mul(a, b Matrix) {
|
||||
Stride: avec.Inc,
|
||||
Data: avec.Data,
|
||||
}
|
||||
blas64.Gemm(aT, bT, 1, amat, bmat, 0, m.mat)
|
||||
blas64.Gemm(aT, bT, 1, amat, bU.mat, 0, m.mat)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user