mirror of
https://github.com/gonum/gonum.git
synced 2025-10-16 12:10:37 +08:00
mat: generalise Outer vector parameters
This commit is contained in:
@@ -679,13 +679,22 @@ func (m *Dense) RankOne(a Matrix, alpha float64, x, y *VecDense) {
|
||||
*m = w
|
||||
}
|
||||
|
||||
// Outer calculates the outer product of x and y, and stores the result
|
||||
// in the receiver.
|
||||
// Outer calculates the outer product of the column vectors x and y,
|
||||
// and stores the result in the receiver.
|
||||
// m = alpha * x * y'
|
||||
// In order to update an existing matrix, see RankOne.
|
||||
func (m *Dense) Outer(alpha float64, x, y *VecDense) {
|
||||
r := x.Len()
|
||||
c := y.Len()
|
||||
func (m *Dense) Outer(alpha float64, x, y Vector) {
|
||||
xr, xc := x.Dims()
|
||||
if xc != 1 {
|
||||
panic(ErrShape)
|
||||
}
|
||||
yr, yc := y.Dims()
|
||||
if yc != 1 {
|
||||
panic(ErrShape)
|
||||
}
|
||||
|
||||
r := xr
|
||||
c := yr
|
||||
|
||||
// Copied from reuseAs with use replaced by useZeroed
|
||||
// and a final zero of the matrix elements if we pass
|
||||
@@ -707,13 +716,36 @@ func (m *Dense) Outer(alpha float64, x, y *VecDense) {
|
||||
m.capCols = c
|
||||
} else if r != m.mat.Rows || c != m.mat.Cols {
|
||||
panic(ErrShape)
|
||||
}
|
||||
|
||||
var xmat, ymat blas64.Vector
|
||||
fast := true
|
||||
xU, _ := untranspose(x)
|
||||
if rv, ok := xU.(RawVectorer); ok {
|
||||
xmat = rv.RawVector()
|
||||
m.checkOverlap((&VecDense{mat: xmat, n: x.Len()}).asGeneral())
|
||||
} else {
|
||||
m.checkOverlap(x.asGeneral())
|
||||
m.checkOverlap(y.asGeneral())
|
||||
fast = false
|
||||
}
|
||||
yU, _ := untranspose(y)
|
||||
if rv, ok := yU.(RawVectorer); ok {
|
||||
ymat = rv.RawVector()
|
||||
m.checkOverlap((&VecDense{mat: ymat, n: y.Len()}).asGeneral())
|
||||
} else {
|
||||
fast = false
|
||||
}
|
||||
|
||||
if fast {
|
||||
for i := 0; i < r; i++ {
|
||||
zero(m.mat.Data[i*m.mat.Stride : i*m.mat.Stride+c])
|
||||
}
|
||||
blas64.Ger(alpha, xmat, ymat, m.mat)
|
||||
return
|
||||
}
|
||||
|
||||
blas64.Ger(alpha, x.mat, y.mat, m.mat)
|
||||
for i := 0; i < r; i++ {
|
||||
for j := 0; j < c; j++ {
|
||||
m.set(i, j, alpha*x.AtVec(i)*y.AtVec(j))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -1557,7 +1557,7 @@ func TestRankOne(t *testing.T) {
|
||||
// Check with the same matrix
|
||||
a.RankOne(a, test.alpha, NewVecDense(len(test.x), test.x), NewVecDense(len(test.y), test.y))
|
||||
if !Equal(a, want) {
|
||||
t.Errorf("unexpected result for Outer test %d iteration 1: got: %+v want: %+v", i, m, want)
|
||||
t.Errorf("unexpected result for RankOne test %d iteration 1: got: %+v want: %+v", i, m, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1611,6 +1611,21 @@ func TestOuter(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, alpha := range []float64{0, 1, -1, 2.3, -2.3} {
|
||||
method := func(receiver, x, y Matrix) {
|
||||
type outerer interface {
|
||||
Outer(alpha float64, x, y Vector)
|
||||
}
|
||||
m := receiver.(outerer)
|
||||
m.Outer(alpha, x.(Vector), y.(Vector))
|
||||
}
|
||||
denseComparison := func(receiver, x, y *Dense) {
|
||||
receiver.Mul(x, y.T())
|
||||
receiver.Scale(alpha, receiver)
|
||||
}
|
||||
testTwoInput(t, "Outer", &Dense{}, method, denseComparison, legalTypesVectorVector, legalSizeVector, 1e-12)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInverse(t *testing.T) {
|
||||
|
@@ -57,6 +57,11 @@ func legalSizeSolve(ar, ac, br, bc int) bool {
|
||||
return ar == br
|
||||
}
|
||||
|
||||
// legalSizeSameVec returns whether the two matrices are column vectors.
|
||||
func legalSizeVector(_, ac, _, bc int) bool {
|
||||
return ac == 1 && bc == 1
|
||||
}
|
||||
|
||||
// legalSizeSameVec returns whether the two matrices are column vectors of the
|
||||
// same dimension.
|
||||
func legalSizeSameVec(ar, ac, br, bc int) bool {
|
||||
@@ -73,8 +78,8 @@ func isAnySize2(ar, ac, br, bc int) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// isAnyVecDense returns true for any column vector sizes.
|
||||
func isAnyVecDense(ar, ac int) bool {
|
||||
// isAnyColumnVector returns true for any column vector sizes.
|
||||
func isAnyColumnVector(ar, ac int) bool {
|
||||
return ac == 1
|
||||
}
|
||||
|
||||
|
@@ -274,7 +274,7 @@ func TestVecDenseScale(t *testing.T) {
|
||||
denseComparison := func(receiver, a *Dense) {
|
||||
receiver.Scale(alpha, a)
|
||||
}
|
||||
testOneInput(t, "ScaleVec", &VecDense{}, method, denseComparison, legalTypeVector, isAnyVecDense, 0)
|
||||
testOneInput(t, "ScaleVec", &VecDense{}, method, denseComparison, legalTypeVector, isAnyColumnVector, 0)
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user