diff --git a/mat/dense_arithmetic.go b/mat/dense_arithmetic.go index e808888e..a08d7ab6 100644 --- a/mat/dense_arithmetic.go +++ b/mat/dense_arithmetic.go @@ -679,13 +679,22 @@ func (m *Dense) RankOne(a Matrix, alpha float64, x, y *VecDense) { *m = w } -// Outer calculates the outer product of x and y, and stores the result -// in the receiver. +// Outer calculates the outer product of the column vectors x and y, +// and stores the result in the receiver. // m = alpha * x * y' // In order to update an existing matrix, see RankOne. -func (m *Dense) Outer(alpha float64, x, y *VecDense) { - r := x.Len() - c := y.Len() +func (m *Dense) Outer(alpha float64, x, y Vector) { + xr, xc := x.Dims() + if xc != 1 { + panic(ErrShape) + } + yr, yc := y.Dims() + if yc != 1 { + panic(ErrShape) + } + + r := xr + c := yr // Copied from reuseAs with use replaced by useZeroed // and a final zero of the matrix elements if we pass @@ -707,13 +716,36 @@ func (m *Dense) Outer(alpha float64, x, y *VecDense) { m.capCols = c } else if r != m.mat.Rows || c != m.mat.Cols { panic(ErrShape) + } + + var xmat, ymat blas64.Vector + fast := true + xU, _ := untranspose(x) + if rv, ok := xU.(RawVectorer); ok { + xmat = rv.RawVector() + m.checkOverlap((&VecDense{mat: xmat, n: x.Len()}).asGeneral()) } else { - m.checkOverlap(x.asGeneral()) - m.checkOverlap(y.asGeneral()) + fast = false + } + yU, _ := untranspose(y) + if rv, ok := yU.(RawVectorer); ok { + ymat = rv.RawVector() + m.checkOverlap((&VecDense{mat: ymat, n: y.Len()}).asGeneral()) + } else { + fast = false + } + + if fast { for i := 0; i < r; i++ { zero(m.mat.Data[i*m.mat.Stride : i*m.mat.Stride+c]) } + blas64.Ger(alpha, xmat, ymat, m.mat) + return } - blas64.Ger(alpha, x.mat, y.mat, m.mat) + for i := 0; i < r; i++ { + for j := 0; j < c; j++ { + m.set(i, j, alpha*x.AtVec(i)*y.AtVec(j)) + } + } } diff --git a/mat/dense_test.go b/mat/dense_test.go index fc15ea26..dce3076b 100644 --- a/mat/dense_test.go +++ b/mat/dense_test.go @@ -1557,7 +1557,7 @@ func TestRankOne(t *testing.T) { // Check with the same matrix a.RankOne(a, test.alpha, NewVecDense(len(test.x), test.x), NewVecDense(len(test.y), test.y)) if !Equal(a, want) { - t.Errorf("unexpected result for Outer test %d iteration 1: got: %+v want: %+v", i, m, want) + t.Errorf("unexpected result for RankOne test %d iteration 1: got: %+v want: %+v", i, m, want) } } } @@ -1611,6 +1611,21 @@ func TestOuter(t *testing.T) { } } } + + for _, alpha := range []float64{0, 1, -1, 2.3, -2.3} { + method := func(receiver, x, y Matrix) { + type outerer interface { + Outer(alpha float64, x, y Vector) + } + m := receiver.(outerer) + m.Outer(alpha, x.(Vector), y.(Vector)) + } + denseComparison := func(receiver, x, y *Dense) { + receiver.Mul(x, y.T()) + receiver.Scale(alpha, receiver) + } + testTwoInput(t, "Outer", &Dense{}, method, denseComparison, legalTypesVectorVector, legalSizeVector, 1e-12) + } } func TestInverse(t *testing.T) { diff --git a/mat/list_test.go b/mat/list_test.go index 8eefa2c6..feb72211 100644 --- a/mat/list_test.go +++ b/mat/list_test.go @@ -57,6 +57,11 @@ func legalSizeSolve(ar, ac, br, bc int) bool { return ar == br } +// legalSizeSameVec returns whether the two matrices are column vectors. +func legalSizeVector(_, ac, _, bc int) bool { + return ac == 1 && bc == 1 +} + // legalSizeSameVec returns whether the two matrices are column vectors of the // same dimension. func legalSizeSameVec(ar, ac, br, bc int) bool { @@ -73,8 +78,8 @@ func isAnySize2(ar, ac, br, bc int) bool { return true } -// isAnyVecDense returns true for any column vector sizes. -func isAnyVecDense(ar, ac int) bool { +// isAnyColumnVector returns true for any column vector sizes. +func isAnyColumnVector(ar, ac int) bool { return ac == 1 } diff --git a/mat/vector_test.go b/mat/vector_test.go index be9484af..6e08672f 100644 --- a/mat/vector_test.go +++ b/mat/vector_test.go @@ -274,7 +274,7 @@ func TestVecDenseScale(t *testing.T) { denseComparison := func(receiver, a *Dense) { receiver.Scale(alpha, a) } - testOneInput(t, "ScaleVec", &VecDense{}, method, denseComparison, legalTypeVector, isAnyVecDense, 0) + testOneInput(t, "ScaleVec", &VecDense{}, method, denseComparison, legalTypeVector, isAnyColumnVector, 0) } }