mirror of
https://github.com/gonum/gonum.git
synced 2025-10-14 19:26:30 +08:00
mat: generalise SymDense RankTwo vector parameters
This commit is contained in:
@@ -341,27 +341,63 @@ func (s *SymDense) SymOuterK(alpha float64, x Matrix) {
|
||||
// RankTwo performs a symmmetric rank-two update to the matrix a and stores
|
||||
// the result in the receiver
|
||||
// m = a + alpha * (x * y' + y * x')
|
||||
func (s *SymDense) RankTwo(a Symmetric, alpha float64, x, y *VecDense) {
|
||||
func (s *SymDense) RankTwo(a Symmetric, alpha float64, x, y Vector) {
|
||||
n := s.mat.N
|
||||
if x.Len() != n {
|
||||
xr, xc := x.Dims()
|
||||
if xr != n || xc != 1 {
|
||||
panic(ErrShape)
|
||||
}
|
||||
if y.Len() != n {
|
||||
yr, yc := y.Dims()
|
||||
if yr != n || yc != 1 {
|
||||
panic(ErrShape)
|
||||
}
|
||||
var w SymDense
|
||||
if s == a {
|
||||
w = *s
|
||||
}
|
||||
w.reuseAs(n)
|
||||
|
||||
if s != a {
|
||||
if rs, ok := a.(RawSymmetricer); ok {
|
||||
s.checkOverlap(generalFromSymmetric(rs.RawSymmetric()))
|
||||
}
|
||||
w.CopySym(a)
|
||||
}
|
||||
blas64.Syr2(alpha, x.mat, y.mat, w.mat)
|
||||
*s = w
|
||||
|
||||
var xmat, ymat blas64.Vector
|
||||
fast := true
|
||||
xU, _ := untranspose(x)
|
||||
if rv, ok := xU.(RawVectorer); ok {
|
||||
xmat = rv.RawVector()
|
||||
s.checkOverlap((&VecDense{mat: xmat, n: x.Len()}).asGeneral())
|
||||
} else {
|
||||
fast = false
|
||||
}
|
||||
yU, _ := untranspose(y)
|
||||
if rv, ok := yU.(RawVectorer); ok {
|
||||
ymat = rv.RawVector()
|
||||
s.checkOverlap((&VecDense{mat: ymat, n: y.Len()}).asGeneral())
|
||||
} else {
|
||||
fast = false
|
||||
}
|
||||
|
||||
if s != a {
|
||||
if rs, ok := a.(RawSymmetricer); ok {
|
||||
s.checkOverlap(generalFromSymmetric(rs.RawSymmetric()))
|
||||
}
|
||||
s.reuseAs(n)
|
||||
s.CopySym(a)
|
||||
}
|
||||
|
||||
if fast {
|
||||
if s != a {
|
||||
s.reuseAs(n)
|
||||
s.CopySym(a)
|
||||
}
|
||||
blas64.Syr2(alpha, xmat, ymat, s.mat)
|
||||
return
|
||||
}
|
||||
|
||||
for i := 0; i < n; i++ {
|
||||
s.reuseAs(n)
|
||||
for j := i; j < n; j++ {
|
||||
s.set(i, j, a.At(i, j)+alpha*(x.AtVec(i)*y.AtVec(j)+y.AtVec(i)*x.AtVec(j)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ScaleSym multiplies the elements of a by f, placing the result in the receiver.
|
||||
|
Reference in New Issue
Block a user