mat: ensure number of elements to be copied is sanitised for blas64

This commit is contained in:
Dan Kortschak
2019-07-04 15:39:34 +09:30
parent 536a303fd6
commit 0a381ca743
2 changed files with 31 additions and 1 deletions

View File

@@ -226,7 +226,11 @@ func (v *VecDense) CopyVec(a Vector) int {
return n
}
if r, ok := a.(RawVectorer); ok {
blas64.Copy(r.RawVector(), v.mat)
src := r.RawVector()
src.N = n
dst := v.mat
dst.N = n
blas64.Copy(src, dst)
return n
}
for i := 0; i < n; i++ {

View File

@@ -310,6 +310,32 @@ func TestVecDenseScale(t *testing.T) {
}
}
func TestCopyVec(t *testing.T) {
for i, test := range []struct {
src *VecDense
dst *VecDense
want *VecDense
wantN int
}{
{src: NewVecDense(1, nil), dst: NewVecDense(1, nil), want: NewVecDense(1, nil), wantN: 1},
{src: NewVecDense(3, []float64{1, 2, 3}), dst: NewVecDense(2, []float64{-1, -2}), want: NewVecDense(2, []float64{1, 2}), wantN: 2},
{src: NewVecDense(2, []float64{1, 2}), dst: NewVecDense(3, []float64{-1, -2, -3}), want: NewVecDense(3, []float64{1, 2, -3}), wantN: 2},
} {
got := test.dst
var n int
panicked, message := panics(func() { n = got.CopyVec(test.src) })
if panicked {
t.Errorf("unexpected panic during vector copy for test %d: %s", i, message)
}
if !Equal(got, test.want) {
t.Errorf("test %d: unexpected result CopyVec:\ngot: %v\nwant:%v", i, got, test.want)
}
if n != test.wantN {
t.Errorf("test %d: unexpected result number of elements copied: got:%d want:%d", i, n, test.wantN)
}
}
}
func TestVecDenseAddScaled(t *testing.T) {
for _, alpha := range []float64{0, 1, -1, 2.3, -2.3} {
method := func(receiver, a, b Matrix) {