mirror of
https://github.com/gonum/gonum.git
synced 2025-10-05 23:26:52 +08:00
mat: generalise Dense RankOne vector parameters
This commit is contained in:
@@ -653,30 +653,56 @@ func (m *Dense) Apply(fn func(i, j int, v float64) float64, a Matrix) {
|
||||
// RankOne performs a rank-one update to the matrix a and stores the result
|
||||
// in the receiver. If a is zero, see Outer.
|
||||
// m = a + alpha * x * y'
|
||||
func (m *Dense) RankOne(a Matrix, alpha float64, x, y *VecDense) {
|
||||
func (m *Dense) RankOne(a Matrix, alpha float64, x, y Vector) {
|
||||
ar, ac := a.Dims()
|
||||
if x.Len() != ar {
|
||||
xr, xc := x.Dims()
|
||||
if xr != ar || xc != 1 {
|
||||
panic(ErrShape)
|
||||
}
|
||||
if y.Len() != ac {
|
||||
yr, yc := y.Dims()
|
||||
if yr != ac || yc != 1 {
|
||||
panic(ErrShape)
|
||||
}
|
||||
|
||||
m.checkOverlap(x.asGeneral())
|
||||
m.checkOverlap(y.asGeneral())
|
||||
|
||||
var w Dense
|
||||
if m == a {
|
||||
w = *m
|
||||
if a != m {
|
||||
aU, _ := untranspose(a)
|
||||
if rm, ok := aU.(RawMatrixer); ok {
|
||||
m.checkOverlap(rm.RawMatrix())
|
||||
}
|
||||
}
|
||||
w.reuseAs(ar, ac)
|
||||
|
||||
// Copy over to the new memory if necessary
|
||||
if m != a {
|
||||
w.Copy(a)
|
||||
var xmat, ymat blas64.Vector
|
||||
fast := true
|
||||
xU, _ := untranspose(x)
|
||||
if rv, ok := xU.(RawVectorer); ok {
|
||||
xmat = rv.RawVector()
|
||||
m.checkOverlap((&VecDense{mat: xmat, n: x.Len()}).asGeneral())
|
||||
} else {
|
||||
fast = false
|
||||
}
|
||||
yU, _ := untranspose(y)
|
||||
if rv, ok := yU.(RawVectorer); ok {
|
||||
ymat = rv.RawVector()
|
||||
m.checkOverlap((&VecDense{mat: ymat, n: y.Len()}).asGeneral())
|
||||
} else {
|
||||
fast = false
|
||||
}
|
||||
|
||||
if fast {
|
||||
if m != a {
|
||||
m.reuseAs(ar, ac)
|
||||
m.Copy(a)
|
||||
}
|
||||
blas64.Ger(alpha, xmat, ymat, m.mat)
|
||||
return
|
||||
}
|
||||
|
||||
m.reuseAs(ar, ac)
|
||||
for i := 0; i < ar; i++ {
|
||||
for j := 0; j < ac; j++ {
|
||||
m.set(i, j, a.At(i, j)+alpha*x.AtVec(i)*y.AtVec(j))
|
||||
}
|
||||
}
|
||||
blas64.Ger(alpha, x.mat, y.mat, w.mat)
|
||||
*m = w
|
||||
}
|
||||
|
||||
// Outer calculates the outer product of the column vectors x and y,
|
||||
|
Reference in New Issue
Block a user