mat: generalise Dense RankOne vector parameters

This commit is contained in:
kortschak
2018-01-03 13:11:19 +10:30
committed by Dan Kortschak
parent 7d975f4c67
commit 146e16d5b4

View File

@@ -653,30 +653,56 @@ func (m *Dense) Apply(fn func(i, j int, v float64) float64, a Matrix) {
// RankOne performs a rank-one update to the matrix a and stores the result
// in the receiver. If a is zero, see Outer.
// m = a + alpha * x * y'
func (m *Dense) RankOne(a Matrix, alpha float64, x, y *VecDense) {
func (m *Dense) RankOne(a Matrix, alpha float64, x, y Vector) {
ar, ac := a.Dims()
if x.Len() != ar {
xr, xc := x.Dims()
if xr != ar || xc != 1 {
panic(ErrShape)
}
if y.Len() != ac {
yr, yc := y.Dims()
if yr != ac || yc != 1 {
panic(ErrShape)
}
m.checkOverlap(x.asGeneral())
m.checkOverlap(y.asGeneral())
var w Dense
if m == a {
w = *m
if a != m {
aU, _ := untranspose(a)
if rm, ok := aU.(RawMatrixer); ok {
m.checkOverlap(rm.RawMatrix())
}
}
w.reuseAs(ar, ac)
// Copy over to the new memory if necessary
if m != a {
w.Copy(a)
var xmat, ymat blas64.Vector
fast := true
xU, _ := untranspose(x)
if rv, ok := xU.(RawVectorer); ok {
xmat = rv.RawVector()
m.checkOverlap((&VecDense{mat: xmat, n: x.Len()}).asGeneral())
} else {
fast = false
}
yU, _ := untranspose(y)
if rv, ok := yU.(RawVectorer); ok {
ymat = rv.RawVector()
m.checkOverlap((&VecDense{mat: ymat, n: y.Len()}).asGeneral())
} else {
fast = false
}
if fast {
if m != a {
m.reuseAs(ar, ac)
m.Copy(a)
}
blas64.Ger(alpha, xmat, ymat, m.mat)
return
}
m.reuseAs(ar, ac)
for i := 0; i < ar; i++ {
for j := 0; j < ac; j++ {
m.set(i, j, a.At(i, j)+alpha*x.AtVec(i)*y.AtVec(j))
}
}
blas64.Ger(alpha, x.mat, y.mat, w.mat)
*m = w
}
// Outer calculates the outer product of the column vectors x and y,