mat: generalise SymDense RankTwo vector parameters

This commit is contained in:
kortschak
2018-01-03 13:47:00 +10:30
committed by Dan Kortschak
parent 003363ab06
commit 463c9f711f

View File

@@ -341,27 +341,63 @@ func (s *SymDense) SymOuterK(alpha float64, x Matrix) {
// RankTwo performs a symmmetric rank-two update to the matrix a and stores // RankTwo performs a symmmetric rank-two update to the matrix a and stores
// the result in the receiver // the result in the receiver
// m = a + alpha * (x * y' + y * x') // m = a + alpha * (x * y' + y * x')
func (s *SymDense) RankTwo(a Symmetric, alpha float64, x, y *VecDense) { func (s *SymDense) RankTwo(a Symmetric, alpha float64, x, y Vector) {
n := s.mat.N n := s.mat.N
if x.Len() != n { xr, xc := x.Dims()
if xr != n || xc != 1 {
panic(ErrShape) panic(ErrShape)
} }
if y.Len() != n { yr, yc := y.Dims()
if yr != n || yc != 1 {
panic(ErrShape) panic(ErrShape)
} }
var w SymDense
if s == a {
w = *s
}
w.reuseAs(n)
if s != a { if s != a {
if rs, ok := a.(RawSymmetricer); ok { if rs, ok := a.(RawSymmetricer); ok {
s.checkOverlap(generalFromSymmetric(rs.RawSymmetric())) s.checkOverlap(generalFromSymmetric(rs.RawSymmetric()))
} }
w.CopySym(a)
} }
blas64.Syr2(alpha, x.mat, y.mat, w.mat)
*s = w var xmat, ymat blas64.Vector
fast := true
xU, _ := untranspose(x)
if rv, ok := xU.(RawVectorer); ok {
xmat = rv.RawVector()
s.checkOverlap((&VecDense{mat: xmat, n: x.Len()}).asGeneral())
} else {
fast = false
}
yU, _ := untranspose(y)
if rv, ok := yU.(RawVectorer); ok {
ymat = rv.RawVector()
s.checkOverlap((&VecDense{mat: ymat, n: y.Len()}).asGeneral())
} else {
fast = false
}
if s != a {
if rs, ok := a.(RawSymmetricer); ok {
s.checkOverlap(generalFromSymmetric(rs.RawSymmetric()))
}
s.reuseAs(n)
s.CopySym(a)
}
if fast {
if s != a {
s.reuseAs(n)
s.CopySym(a)
}
blas64.Syr2(alpha, xmat, ymat, s.mat)
return
}
for i := 0; i < n; i++ {
s.reuseAs(n)
for j := i; j < n; j++ {
s.set(i, j, a.At(i, j)+alpha*(x.AtVec(i)*y.AtVec(j)+y.AtVec(i)*x.AtVec(j)))
}
}
} }
// ScaleSym multiplies the elements of a by f, placing the result in the receiver. // ScaleSym multiplies the elements of a by f, placing the result in the receiver.