mirror of
https://github.com/gonum/gonum.git
synced 2025-10-06 07:37:03 +08:00
mat: generalise Dense RankOne vector parameters
This commit is contained in:
@@ -653,30 +653,56 @@ func (m *Dense) Apply(fn func(i, j int, v float64) float64, a Matrix) {
|
|||||||
// RankOne performs a rank-one update to the matrix a and stores the result
|
// RankOne performs a rank-one update to the matrix a and stores the result
|
||||||
// in the receiver. If a is zero, see Outer.
|
// in the receiver. If a is zero, see Outer.
|
||||||
// m = a + alpha * x * y'
|
// m = a + alpha * x * y'
|
||||||
func (m *Dense) RankOne(a Matrix, alpha float64, x, y *VecDense) {
|
func (m *Dense) RankOne(a Matrix, alpha float64, x, y Vector) {
|
||||||
ar, ac := a.Dims()
|
ar, ac := a.Dims()
|
||||||
if x.Len() != ar {
|
xr, xc := x.Dims()
|
||||||
|
if xr != ar || xc != 1 {
|
||||||
panic(ErrShape)
|
panic(ErrShape)
|
||||||
}
|
}
|
||||||
if y.Len() != ac {
|
yr, yc := y.Dims()
|
||||||
|
if yr != ac || yc != 1 {
|
||||||
panic(ErrShape)
|
panic(ErrShape)
|
||||||
}
|
}
|
||||||
|
|
||||||
m.checkOverlap(x.asGeneral())
|
if a != m {
|
||||||
m.checkOverlap(y.asGeneral())
|
aU, _ := untranspose(a)
|
||||||
|
if rm, ok := aU.(RawMatrixer); ok {
|
||||||
var w Dense
|
m.checkOverlap(rm.RawMatrix())
|
||||||
if m == a {
|
}
|
||||||
w = *m
|
|
||||||
}
|
}
|
||||||
w.reuseAs(ar, ac)
|
|
||||||
|
|
||||||
// Copy over to the new memory if necessary
|
var xmat, ymat blas64.Vector
|
||||||
if m != a {
|
fast := true
|
||||||
w.Copy(a)
|
xU, _ := untranspose(x)
|
||||||
|
if rv, ok := xU.(RawVectorer); ok {
|
||||||
|
xmat = rv.RawVector()
|
||||||
|
m.checkOverlap((&VecDense{mat: xmat, n: x.Len()}).asGeneral())
|
||||||
|
} else {
|
||||||
|
fast = false
|
||||||
|
}
|
||||||
|
yU, _ := untranspose(y)
|
||||||
|
if rv, ok := yU.(RawVectorer); ok {
|
||||||
|
ymat = rv.RawVector()
|
||||||
|
m.checkOverlap((&VecDense{mat: ymat, n: y.Len()}).asGeneral())
|
||||||
|
} else {
|
||||||
|
fast = false
|
||||||
|
}
|
||||||
|
|
||||||
|
if fast {
|
||||||
|
if m != a {
|
||||||
|
m.reuseAs(ar, ac)
|
||||||
|
m.Copy(a)
|
||||||
|
}
|
||||||
|
blas64.Ger(alpha, xmat, ymat, m.mat)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
m.reuseAs(ar, ac)
|
||||||
|
for i := 0; i < ar; i++ {
|
||||||
|
for j := 0; j < ac; j++ {
|
||||||
|
m.set(i, j, a.At(i, j)+alpha*x.AtVec(i)*y.AtVec(j))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
blas64.Ger(alpha, x.mat, y.mat, w.mat)
|
|
||||||
*m = w
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Outer calculates the outer product of the column vectors x and y,
|
// Outer calculates the outer product of the column vectors x and y,
|
||||||
|
Reference in New Issue
Block a user