mat: replace untranspose with untransposeExtract (#1036)

* mat: replace untranspose with untransposeExtract

This leaves one remaining use in TriDense but that case is a bit more complicated
This commit is contained in:
Brendan Tracey
2019-07-24 14:37:15 +01:00
committed by GitHub
parent 2122b538b6
commit a8659125a9
4 changed files with 103 additions and 103 deletions

View File

@@ -21,13 +21,13 @@ func (m *Dense) Add(a, b Matrix) {
panic(ErrShape)
}
aU, _ := untranspose(a)
bU, _ := untranspose(b)
aU, _ := untransposeExtract(a)
bU, _ := untransposeExtract(b)
m.reuseAs(ar, ac)
if arm, ok := a.(RawMatrixer); ok {
if brm, ok := b.(RawMatrixer); ok {
amat, bmat := arm.RawMatrix(), brm.RawMatrix()
if arm, ok := a.(*Dense); ok {
if brm, ok := b.(*Dense); ok {
amat, bmat := arm.mat, brm.mat
if m != aU {
m.checkOverlap(amat)
}
@@ -70,13 +70,13 @@ func (m *Dense) Sub(a, b Matrix) {
panic(ErrShape)
}
aU, _ := untranspose(a)
bU, _ := untranspose(b)
aU, _ := untransposeExtract(a)
bU, _ := untransposeExtract(b)
m.reuseAs(ar, ac)
if arm, ok := a.(RawMatrixer); ok {
if brm, ok := b.(RawMatrixer); ok {
amat, bmat := arm.RawMatrix(), brm.RawMatrix()
if arm, ok := a.(*Dense); ok {
if brm, ok := b.(*Dense); ok {
amat, bmat := arm.mat, brm.mat
if m != aU {
m.checkOverlap(amat)
}
@@ -120,13 +120,13 @@ func (m *Dense) MulElem(a, b Matrix) {
panic(ErrShape)
}
aU, _ := untranspose(a)
bU, _ := untranspose(b)
aU, _ := untransposeExtract(a)
bU, _ := untransposeExtract(b)
m.reuseAs(ar, ac)
if arm, ok := a.(RawMatrixer); ok {
if brm, ok := b.(RawMatrixer); ok {
amat, bmat := arm.RawMatrix(), brm.RawMatrix()
if arm, ok := a.(*Dense); ok {
if brm, ok := b.(*Dense); ok {
amat, bmat := arm.mat, brm.mat
if m != aU {
m.checkOverlap(amat)
}
@@ -170,13 +170,13 @@ func (m *Dense) DivElem(a, b Matrix) {
panic(ErrShape)
}
aU, _ := untranspose(a)
bU, _ := untranspose(b)
aU, _ := untransposeExtract(a)
bU, _ := untransposeExtract(b)
m.reuseAs(ar, ac)
if arm, ok := a.(RawMatrixer); ok {
if brm, ok := b.(RawMatrixer); ok {
amat, bmat := arm.RawMatrix(), brm.RawMatrix()
if arm, ok := a.(*Dense); ok {
if brm, ok := b.(*Dense); ok {
amat, bmat := arm.mat, brm.mat
if m != aU {
m.checkOverlap(amat)
}
@@ -221,11 +221,11 @@ func (m *Dense) Inverse(a Matrix) error {
panic(ErrSquare)
}
m.reuseAs(a.Dims())
aU, aTrans := untranspose(a)
aU, aTrans := untransposeExtract(a)
switch rm := aU.(type) {
case RawMatrixer:
case *Dense:
if m != aU || aTrans {
if m == aU || m.checkOverlap(rm.RawMatrix()) {
if m == aU || m.checkOverlap(rm.mat) {
tmp := getWorkspace(r, c, false)
tmp.Copy(a)
m.Copy(tmp)
@@ -681,9 +681,9 @@ func (m *Dense) Scale(f float64, a Matrix) {
m.reuseAs(ar, ac)
aU, aTrans := untranspose(a)
if rm, ok := aU.(RawMatrixer); ok {
amat := rm.RawMatrix()
aU, aTrans := untransposeExtract(a)
if rm, ok := aU.(*Dense); ok {
amat := rm.mat
if m == aU || m.checkOverlap(amat) {
var restore func()
m, restore = m.isolatedWorkspace(a)
@@ -721,9 +721,9 @@ func (m *Dense) Apply(fn func(i, j int, v float64) float64, a Matrix) {
m.reuseAs(ar, ac)
aU, aTrans := untranspose(a)
if rm, ok := aU.(RawMatrixer); ok {
amat := rm.RawMatrix()
aU, aTrans := untransposeExtract(a)
if rm, ok := aU.(*Dense); ok {
amat := rm.mat
if m == aU || m.checkOverlap(amat) {
var restore func()
m, restore = m.isolatedWorkspace(a)
@@ -767,26 +767,26 @@ func (m *Dense) RankOne(a Matrix, alpha float64, x, y Vector) {
}
if a != m {
aU, _ := untranspose(a)
if rm, ok := aU.(RawMatrixer); ok {
aU, _ := untransposeExtract(a)
if rm, ok := aU.(*Dense); ok {
m.checkOverlap(rm.RawMatrix())
}
}
var xmat, ymat blas64.Vector
fast := true
xU, _ := untranspose(x)
if rv, ok := xU.(RawVectorer); ok {
xU, _ := untransposeExtract(x)
if rv, ok := xU.(*VecDense); ok {
r, c := xU.Dims()
xmat = rv.RawVector()
xmat = rv.mat
m.checkOverlap(generalFromVector(xmat, r, c))
} else {
fast = false
}
yU, _ := untranspose(y)
if rv, ok := yU.(RawVectorer); ok {
yU, _ := untransposeExtract(y)
if rv, ok := yU.(*VecDense); ok {
r, c := yU.Dims()
ymat = rv.RawVector()
ymat = rv.mat
m.checkOverlap(generalFromVector(ymat, r, c))
} else {
fast = false
@@ -840,18 +840,18 @@ func (m *Dense) Outer(alpha float64, x, y Vector) {
var xmat, ymat blas64.Vector
fast := true
xU, _ := untranspose(x)
if rv, ok := xU.(RawVectorer); ok {
xU, _ := untransposeExtract(x)
if rv, ok := xU.(*VecDense); ok {
r, c := xU.Dims()
xmat = rv.RawVector()
xmat = rv.mat
m.checkOverlap(generalFromVector(xmat, r, c))
} else {
fast = false
}
yU, _ := untranspose(y)
if rv, ok := yU.(RawVectorer); ok {
yU, _ := untransposeExtract(y)
if rv, ok := yU.(*VecDense); ok {
r, c := yU.Dims()
ymat = rv.RawVector()
ymat = rv.mat
m.checkOverlap(generalFromVector(ymat, r, c))
} else {
fast = false