diff --git a/mat/dense_arithmetic.go b/mat/dense_arithmetic.go index a08d7ab6..98de7184 100644 --- a/mat/dense_arithmetic.go +++ b/mat/dense_arithmetic.go @@ -653,30 +653,56 @@ func (m *Dense) Apply(fn func(i, j int, v float64) float64, a Matrix) { // RankOne performs a rank-one update to the matrix a and stores the result // in the receiver. If a is zero, see Outer. // m = a + alpha * x * y' -func (m *Dense) RankOne(a Matrix, alpha float64, x, y *VecDense) { +func (m *Dense) RankOne(a Matrix, alpha float64, x, y Vector) { ar, ac := a.Dims() - if x.Len() != ar { + xr, xc := x.Dims() + if xr != ar || xc != 1 { panic(ErrShape) } - if y.Len() != ac { + yr, yc := y.Dims() + if yr != ac || yc != 1 { panic(ErrShape) } - m.checkOverlap(x.asGeneral()) - m.checkOverlap(y.asGeneral()) - - var w Dense - if m == a { - w = *m + if a != m { + aU, _ := untranspose(a) + if rm, ok := aU.(RawMatrixer); ok { + m.checkOverlap(rm.RawMatrix()) + } } - w.reuseAs(ar, ac) - // Copy over to the new memory if necessary - if m != a { - w.Copy(a) + var xmat, ymat blas64.Vector + fast := true + xU, _ := untranspose(x) + if rv, ok := xU.(RawVectorer); ok { + xmat = rv.RawVector() + m.checkOverlap((&VecDense{mat: xmat, n: x.Len()}).asGeneral()) + } else { + fast = false + } + yU, _ := untranspose(y) + if rv, ok := yU.(RawVectorer); ok { + ymat = rv.RawVector() + m.checkOverlap((&VecDense{mat: ymat, n: y.Len()}).asGeneral()) + } else { + fast = false + } + + if fast { + if m != a { + m.reuseAs(ar, ac) + m.Copy(a) + } + blas64.Ger(alpha, xmat, ymat, m.mat) + return + } + + m.reuseAs(ar, ac) + for i := 0; i < ar; i++ { + for j := 0; j < ac; j++ { + m.set(i, j, a.At(i, j)+alpha*x.AtVec(i)*y.AtVec(j)) + } } - blas64.Ger(alpha, x.mat, y.mat, w.mat) - *m = w } // Outer calculates the outer product of the column vectors x and y,