diff --git a/mat/list_test.go b/mat/list_test.go index d0be00f0..8eefa2c6 100644 --- a/mat/list_test.go +++ b/mat/list_test.go @@ -169,14 +169,20 @@ func legalTypesSym(a, b Matrix) bool { return true } +// legalTypeVector returns whether v is a Vector. +func legalTypeVector(v Matrix) bool { + _, ok := v.(Vector) + return ok +} + // legalTypeVec returns whether v is a *VecDense. -func legalTypeVec(v Matrix) bool { +func legalTypeVecDense(v Matrix) bool { _, ok := v.(*VecDense) return ok } -// legalTypesVecVec returns whether both inputs are Vector -func legalTypesVecVec(a, b Matrix) bool { +// legalTypesVectorVector returns whether both inputs are Vector +func legalTypesVectorVector(a, b Matrix) bool { if _, ok := a.(Vector); !ok { return false } @@ -197,9 +203,16 @@ func legalTypesVecDenseVecDense(a, b Matrix) bool { return true } -// legalTypesNotVecVec returns whether the first input is an arbitrary Matrix +// legalTypesMatrixVector returns whether the first input is an arbitrary Matrix +// and the second input is a Vector. +func legalTypesMatrixVector(a, b Matrix) bool { + _, ok := b.(Vector) + return ok +} + +// legalTypesMatrixVecDense returns whether the first input is an arbitrary Matrix // and the second input is a *VecDense. -func legalTypesNotVecVec(a, b Matrix) bool { +func legalTypesMatrixVecDense(a, b Matrix) bool { _, ok := b.(*VecDense) return ok } diff --git a/mat/matrix_test.go b/mat/matrix_test.go index 371fdc21..35fe6a0f 100644 --- a/mat/matrix_test.go +++ b/mat/matrix_test.go @@ -371,14 +371,18 @@ type basicVector struct { m []float64 } +func (v *basicVector) AtVec(i int) float64 { + if i < 0 || i >= v.Len() { + panic(ErrRowAccess) + } + return v.m[i] +} + func (v *basicVector) At(r, c int) float64 { if c != 0 { panic(ErrColAccess) } - if r < 0 || r >= v.Len() { - panic(ErrRowAccess) - } - return v.m[r] + return v.AtVec(r) } func (v *basicVector) Dims() (r, c int) { @@ -411,7 +415,7 @@ func TestDot(t *testing.T) { } return sum } - testTwoInputFunc(t, "Dot", f, denseComparison, sameAnswerFloatApproxTol(1e-12), legalTypesVecVec, legalSizeSameVec) + testTwoInputFunc(t, "Dot", f, denseComparison, sameAnswerFloatApproxTol(1e-12), legalTypesVectorVector, legalSizeSameVec) } func TestEqual(t *testing.T) { diff --git a/mat/solve_test.go b/mat/solve_test.go index 4b74a005..d6bd6164 100644 --- a/mat/solve_test.go +++ b/mat/solve_test.go @@ -293,5 +293,5 @@ func TestSolveVec(t *testing.T) { denseComparison := func(receiver, a, b *Dense) { receiver.Solve(a, b) } - testTwoInput(t, "SolveVec", &VecDense{}, method, denseComparison, legalTypesNotVecVec, legalSizeSolve, 1e-12) + testTwoInput(t, "SolveVec", &VecDense{}, method, denseComparison, legalTypesMatrixVecDense, legalSizeSolve, 1e-12) } diff --git a/mat/vector_test.go b/mat/vector_test.go index 505357ed..be9484af 100644 --- a/mat/vector_test.go +++ b/mat/vector_test.go @@ -201,7 +201,7 @@ func TestVecDenseMul(t *testing.T) { } return legal } - testTwoInput(t, "MulVec", &VecDense{}, method, denseComparison, legalTypesNotVecVec, legalSizeMulVec, 1e-14) + testTwoInput(t, "MulVec", &VecDense{}, method, denseComparison, legalTypesMatrixVector, legalSizeMulVec, 1e-14) } func TestVecDenseScale(t *testing.T) { @@ -274,7 +274,7 @@ func TestVecDenseScale(t *testing.T) { denseComparison := func(receiver, a *Dense) { receiver.Scale(alpha, a) } - testOneInput(t, "ScaleVec", &VecDense{}, method, denseComparison, legalTypeVec, isAnyVecDense, 0) + testOneInput(t, "ScaleVec", &VecDense{}, method, denseComparison, legalTypeVector, isAnyVecDense, 0) } } @@ -292,7 +292,7 @@ func TestVecDenseAddScaled(t *testing.T) { sb.Scale(alpha, b) receiver.Add(a, &sb) } - testTwoInput(t, "AddScaledVec", &VecDense{}, method, denseComparison, legalTypesVecDenseVecDense, legalSizeSameVec, 1e-14) + testTwoInput(t, "AddScaledVec", &VecDense{}, method, denseComparison, legalTypesVectorVector, legalSizeSameVec, 1e-14) } }