mirror of
https://github.com/gonum/gonum.git
synced 2025-10-14 19:26:30 +08:00
mat: generalise SymDense RankTwo vector parameters
This commit is contained in:
@@ -341,27 +341,63 @@ func (s *SymDense) SymOuterK(alpha float64, x Matrix) {
|
|||||||
// RankTwo performs a symmmetric rank-two update to the matrix a and stores
|
// RankTwo performs a symmmetric rank-two update to the matrix a and stores
|
||||||
// the result in the receiver
|
// the result in the receiver
|
||||||
// m = a + alpha * (x * y' + y * x')
|
// m = a + alpha * (x * y' + y * x')
|
||||||
func (s *SymDense) RankTwo(a Symmetric, alpha float64, x, y *VecDense) {
|
func (s *SymDense) RankTwo(a Symmetric, alpha float64, x, y Vector) {
|
||||||
n := s.mat.N
|
n := s.mat.N
|
||||||
if x.Len() != n {
|
xr, xc := x.Dims()
|
||||||
|
if xr != n || xc != 1 {
|
||||||
panic(ErrShape)
|
panic(ErrShape)
|
||||||
}
|
}
|
||||||
if y.Len() != n {
|
yr, yc := y.Dims()
|
||||||
|
if yr != n || yc != 1 {
|
||||||
panic(ErrShape)
|
panic(ErrShape)
|
||||||
}
|
}
|
||||||
var w SymDense
|
|
||||||
if s == a {
|
|
||||||
w = *s
|
|
||||||
}
|
|
||||||
w.reuseAs(n)
|
|
||||||
if s != a {
|
if s != a {
|
||||||
if rs, ok := a.(RawSymmetricer); ok {
|
if rs, ok := a.(RawSymmetricer); ok {
|
||||||
s.checkOverlap(generalFromSymmetric(rs.RawSymmetric()))
|
s.checkOverlap(generalFromSymmetric(rs.RawSymmetric()))
|
||||||
}
|
}
|
||||||
w.CopySym(a)
|
|
||||||
}
|
}
|
||||||
blas64.Syr2(alpha, x.mat, y.mat, w.mat)
|
|
||||||
*s = w
|
var xmat, ymat blas64.Vector
|
||||||
|
fast := true
|
||||||
|
xU, _ := untranspose(x)
|
||||||
|
if rv, ok := xU.(RawVectorer); ok {
|
||||||
|
xmat = rv.RawVector()
|
||||||
|
s.checkOverlap((&VecDense{mat: xmat, n: x.Len()}).asGeneral())
|
||||||
|
} else {
|
||||||
|
fast = false
|
||||||
|
}
|
||||||
|
yU, _ := untranspose(y)
|
||||||
|
if rv, ok := yU.(RawVectorer); ok {
|
||||||
|
ymat = rv.RawVector()
|
||||||
|
s.checkOverlap((&VecDense{mat: ymat, n: y.Len()}).asGeneral())
|
||||||
|
} else {
|
||||||
|
fast = false
|
||||||
|
}
|
||||||
|
|
||||||
|
if s != a {
|
||||||
|
if rs, ok := a.(RawSymmetricer); ok {
|
||||||
|
s.checkOverlap(generalFromSymmetric(rs.RawSymmetric()))
|
||||||
|
}
|
||||||
|
s.reuseAs(n)
|
||||||
|
s.CopySym(a)
|
||||||
|
}
|
||||||
|
|
||||||
|
if fast {
|
||||||
|
if s != a {
|
||||||
|
s.reuseAs(n)
|
||||||
|
s.CopySym(a)
|
||||||
|
}
|
||||||
|
blas64.Syr2(alpha, xmat, ymat, s.mat)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
s.reuseAs(n)
|
||||||
|
for j := i; j < n; j++ {
|
||||||
|
s.set(i, j, a.At(i, j)+alpha*(x.AtVec(i)*y.AtVec(j)+y.AtVec(i)*x.AtVec(j)))
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ScaleSym multiplies the elements of a by f, placing the result in the receiver.
|
// ScaleSym multiplies the elements of a by f, placing the result in the receiver.
|
||||||
|
Reference in New Issue
Block a user