mat: generalise Outer vector parameters

This commit is contained in:
kortschak
2018-01-03 12:49:58 +10:30
committed by Dan Kortschak
parent 6861c60a47
commit 7d975f4c67
4 changed files with 64 additions and 12 deletions

View File

@@ -679,13 +679,22 @@ func (m *Dense) RankOne(a Matrix, alpha float64, x, y *VecDense) {
*m = w
}
// Outer calculates the outer product of x and y, and stores the result
// in the receiver.
// Outer calculates the outer product of the column vectors x and y,
// and stores the result in the receiver.
// m = alpha * x * y'
// In order to update an existing matrix, see RankOne.
func (m *Dense) Outer(alpha float64, x, y *VecDense) {
r := x.Len()
c := y.Len()
func (m *Dense) Outer(alpha float64, x, y Vector) {
xr, xc := x.Dims()
if xc != 1 {
panic(ErrShape)
}
yr, yc := y.Dims()
if yc != 1 {
panic(ErrShape)
}
r := xr
c := yr
// Copied from reuseAs with use replaced by useZeroed
// and a final zero of the matrix elements if we pass
@@ -707,13 +716,36 @@ func (m *Dense) Outer(alpha float64, x, y *VecDense) {
m.capCols = c
} else if r != m.mat.Rows || c != m.mat.Cols {
panic(ErrShape)
}
var xmat, ymat blas64.Vector
fast := true
xU, _ := untranspose(x)
if rv, ok := xU.(RawVectorer); ok {
xmat = rv.RawVector()
m.checkOverlap((&VecDense{mat: xmat, n: x.Len()}).asGeneral())
} else {
m.checkOverlap(x.asGeneral())
m.checkOverlap(y.asGeneral())
fast = false
}
yU, _ := untranspose(y)
if rv, ok := yU.(RawVectorer); ok {
ymat = rv.RawVector()
m.checkOverlap((&VecDense{mat: ymat, n: y.Len()}).asGeneral())
} else {
fast = false
}
if fast {
for i := 0; i < r; i++ {
zero(m.mat.Data[i*m.mat.Stride : i*m.mat.Stride+c])
}
blas64.Ger(alpha, xmat, ymat, m.mat)
return
}
blas64.Ger(alpha, x.mat, y.mat, m.mat)
for i := 0; i < r; i++ {
for j := 0; j < c; j++ {
m.set(i, j, alpha*x.AtVec(i)*y.AtVec(j))
}
}
}