mat: generalise basic arithmetic vector operations

This commit is contained in:
Dan Kortschak
2017-12-25 06:12:00 +10:30
committed by GitHub
parent b52122b771
commit 6e57d606a5
3 changed files with 287 additions and 153 deletions

View File

@@ -315,17 +315,17 @@ var vectorData = []struct {
}, },
{ {
raw: []byte("\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xf0?\x00\x00\x00\x00\x00\x00\x00@\x00\x00\x00\x00\x00\x00\b@"), raw: []byte("\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xf0?\x00\x00\x00\x00\x00\x00\x00@\x00\x00\x00\x00\x00\x00\b@"),
want: NewVecDense(9, []float64{1, 2, 3, 4, 5, 6, 7, 8, 9}).SliceVec(0, 3), want: NewVecDense(9, []float64{1, 2, 3, 4, 5, 6, 7, 8, 9}).SliceVec(0, 3).(*VecDense),
eq: Equal, eq: Equal,
}, },
{ {
raw: []byte("\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00@\x00\x00\x00\x00\x00\x00\b@\x00\x00\x00\x00\x00\x00\x10@"), raw: []byte("\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00@\x00\x00\x00\x00\x00\x00\b@\x00\x00\x00\x00\x00\x00\x10@"),
want: NewVecDense(9, []float64{1, 2, 3, 4, 5, 6, 7, 8, 9}).SliceVec(1, 4), want: NewVecDense(9, []float64{1, 2, 3, 4, 5, 6, 7, 8, 9}).SliceVec(1, 4).(*VecDense),
eq: Equal, eq: Equal,
}, },
{ {
raw: []byte("\b\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xf0?\x00\x00\x00\x00\x00\x00\x00@\x00\x00\x00\x00\x00\x00\b@\x00\x00\x00\x00\x00\x00\x10@\x00\x00\x00\x00\x00\x00\x14@\x00\x00\x00\x00\x00\x00\x18@\x00\x00\x00\x00\x00\x00\x1c@\x00\x00\x00\x00\x00\x00 @"), raw: []byte("\b\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xf0?\x00\x00\x00\x00\x00\x00\x00@\x00\x00\x00\x00\x00\x00\b@\x00\x00\x00\x00\x00\x00\x10@\x00\x00\x00\x00\x00\x00\x14@\x00\x00\x00\x00\x00\x00\x18@\x00\x00\x00\x00\x00\x00\x1c@\x00\x00\x00\x00\x00\x00 @"),
want: NewVecDense(9, []float64{1, 2, 3, 4, 5, 6, 7, 8, 9}).SliceVec(0, 8), want: NewVecDense(9, []float64{1, 2, 3, 4, 5, 6, 7, 8, 9}).SliceVec(0, 8).(*VecDense),
eq: Equal, eq: Equal,
}, },
{ {

View File

@@ -97,11 +97,11 @@ func NewVecDense(n int, data []float64) *VecDense {
} }
} }
// SliceVec returns a new VecDense that shares backing data with the receiver. // SliceVec returns a new Vector that shares backing data with the receiver.
// The returned matrix starts at i of the receiver and extends k-i elements. // The returned matrix starts at i of the receiver and extends k-i elements.
// SliceVec panics with ErrIndexOutOfRange if the slice is outside the capacity // SliceVec panics with ErrIndexOutOfRange if the slice is outside the capacity
// of the receiver. // of the receiver.
func (v *VecDense) SliceVec(i, k int) *VecDense { func (v *VecDense) SliceVec(i, k int) Vector {
if i < 0 || k <= i || v.Cap() < k { if i < 0 || k <= i || v.Cap() < k {
panic(ErrIndexOutOfRange) panic(ErrIndexOutOfRange)
} }
@@ -196,36 +196,55 @@ func (v *VecDense) RawVector() blas64.Vector {
// CopyVec makes a copy of elements of a into the receiver. It is similar to the // CopyVec makes a copy of elements of a into the receiver. It is similar to the
// built-in copy; it copies as much as the overlap between the two vectors and // built-in copy; it copies as much as the overlap between the two vectors and
// returns the number of elements it copied. // returns the number of elements it copied.
func (v *VecDense) CopyVec(a *VecDense) int { func (v *VecDense) CopyVec(a Vector) int {
n := min(v.Len(), a.Len()) n := min(v.Len(), a.Len())
if v != a { if v == a {
blas64.Copy(n, a.mat, v.mat) return n
}
if r, ok := a.(RawVectorer); ok {
blas64.Copy(n, r.RawVector(), v.mat)
return n
}
for i := 0; i < n; i++ {
v.setVec(i, a.AtVec(i))
} }
return n return n
} }
// ScaleVec scales the vector a by alpha, placing the result in the receiver. // ScaleVec scales the vector a by alpha, placing the result in the receiver.
func (v *VecDense) ScaleVec(alpha float64, a *VecDense) { func (v *VecDense) ScaleVec(alpha float64, a Vector) {
n := a.Len() n := a.Len()
if v != a {
v.reuseAs(n) if v == a {
if v.mat.Inc == 1 && a.mat.Inc == 1 {
f64.ScalUnitaryTo(v.mat.Data, alpha, a.mat.Data)
return
}
f64.ScalIncTo(v.mat.Data, uintptr(v.mat.Inc),
alpha, a.mat.Data, uintptr(n), uintptr(a.mat.Inc))
return
}
if v.mat.Inc == 1 { if v.mat.Inc == 1 {
f64.ScalUnitary(alpha, v.mat.Data) f64.ScalUnitary(alpha, v.mat.Data)
return return
} }
f64.ScalInc(alpha, v.mat.Data, uintptr(n), uintptr(v.mat.Inc)) f64.ScalInc(alpha, v.mat.Data, uintptr(n), uintptr(v.mat.Inc))
return
}
v.reuseAs(n)
if rv, ok := a.(RawVectorer); ok {
mat := rv.RawVector()
v.checkOverlap(mat)
if v.mat.Inc == 1 && mat.Inc == 1 {
f64.ScalUnitaryTo(v.mat.Data, alpha, mat.Data)
return
}
f64.ScalIncTo(v.mat.Data, uintptr(v.mat.Inc),
alpha, mat.Data, uintptr(n), uintptr(mat.Inc))
return
}
for i := 0; i < n; i++ {
v.setVec(i, alpha*a.AtVec(i))
}
} }
// AddScaledVec adds the vectors a and alpha*b, placing the result in the receiver. // AddScaledVec adds the vectors a and alpha*b, placing the result in the receiver.
func (v *VecDense) AddScaledVec(a *VecDense, alpha float64, b *VecDense) { func (v *VecDense) AddScaledVec(a Vector, alpha float64, b Vector) {
if alpha == 1 { if alpha == 1 {
v.AddVec(a, b) v.AddVec(a, b)
return return
@@ -242,42 +261,63 @@ func (v *VecDense) AddScaledVec(a *VecDense, alpha float64, b *VecDense) {
panic(ErrShape) panic(ErrShape)
} }
var amat, bmat blas64.Vector
fast := true
aU, _ := untranspose(a)
if rv, ok := aU.(RawVectorer); ok {
amat = rv.RawVector()
if v != a { if v != a {
v.checkOverlap(a.mat) v.checkOverlap(amat)
} }
} else {
fast = false
}
bU, _ := untranspose(b)
if rv, ok := bU.(RawVectorer); ok {
bmat = rv.RawVector()
if v != b { if v != b {
v.checkOverlap(b.mat) v.checkOverlap(bmat)
}
} else {
fast = false
} }
v.reuseAs(ar) v.reuseAs(ar)
switch { switch {
case alpha == 0: // v <- a case alpha == 0: // v <- a
if v == a {
return
}
v.CopyVec(a) v.CopyVec(a)
case v == a && v == b: // v <- v + alpha * v = (alpha + 1) * v case v == a && v == b: // v <- v + alpha * v = (alpha + 1) * v
blas64.Scal(ar, alpha+1, v.mat) blas64.Scal(ar, alpha+1, v.mat)
case !fast: // v <- a + alpha * b without blas64 support.
for i := 0; i < ar; i++ {
v.setVec(i, a.AtVec(i)+alpha*b.AtVec(i))
}
case v == a && v != b: // v <- v + alpha * b case v == a && v != b: // v <- v + alpha * b
if v.mat.Inc == 1 && b.mat.Inc == 1 { if v.mat.Inc == 1 && bmat.Inc == 1 {
// Fast path for a common case. // Fast path for a common case.
f64.AxpyUnitaryTo(v.mat.Data, alpha, b.mat.Data, a.mat.Data) f64.AxpyUnitaryTo(v.mat.Data, alpha, bmat.Data, amat.Data)
} else { } else {
f64.AxpyInc(alpha, b.mat.Data, v.mat.Data, f64.AxpyInc(alpha, bmat.Data, v.mat.Data,
uintptr(ar), uintptr(b.mat.Inc), uintptr(v.mat.Inc), 0, 0) uintptr(ar), uintptr(bmat.Inc), uintptr(v.mat.Inc), 0, 0)
} }
default: // v <- a + alpha * b or v <- a + alpha * v default: // v <- a + alpha * b or v <- a + alpha * v
if v.mat.Inc == 1 && a.mat.Inc == 1 && b.mat.Inc == 1 { if v.mat.Inc == 1 && amat.Inc == 1 && bmat.Inc == 1 {
// Fast path for a common case. // Fast path for a common case.
f64.AxpyUnitaryTo(v.mat.Data, alpha, b.mat.Data, a.mat.Data) f64.AxpyUnitaryTo(v.mat.Data, alpha, bmat.Data, amat.Data)
} else { } else {
f64.AxpyIncTo(v.mat.Data, uintptr(v.mat.Inc), 0, f64.AxpyIncTo(v.mat.Data, uintptr(v.mat.Inc), 0,
alpha, b.mat.Data, a.mat.Data, alpha, bmat.Data, amat.Data,
uintptr(ar), uintptr(b.mat.Inc), uintptr(a.mat.Inc), 0, 0) uintptr(ar), uintptr(bmat.Inc), uintptr(amat.Inc), 0, 0)
} }
} }
} }
// AddVec adds the vectors a and b, placing the result in the receiver. // AddVec adds the vectors a and b, placing the result in the receiver.
func (v *VecDense) AddVec(a, b *VecDense) { func (v *VecDense) AddVec(a, b Vector) {
ar := a.Len() ar := a.Len()
br := b.Len() br := b.Len()
@@ -285,27 +325,42 @@ func (v *VecDense) AddVec(a, b *VecDense) {
panic(ErrShape) panic(ErrShape)
} }
if v != a {
v.checkOverlap(a.mat)
}
if v != b {
v.checkOverlap(b.mat)
}
v.reuseAs(ar) v.reuseAs(ar)
if v.mat.Inc == 1 && a.mat.Inc == 1 && b.mat.Inc == 1 { aU, _ := untranspose(a)
bU, _ := untranspose(b)
if arv, ok := aU.(RawVectorer); ok {
if brv, ok := bU.(RawVectorer); ok {
amat := arv.RawVector()
bmat := brv.RawVector()
if v != a {
v.checkOverlap(amat)
}
if v != b {
v.checkOverlap(bmat)
}
if v.mat.Inc == 1 && amat.Inc == 1 && bmat.Inc == 1 {
// Fast path for a common case. // Fast path for a common case.
f64.AxpyUnitaryTo(v.mat.Data, 1, b.mat.Data, a.mat.Data) f64.AxpyUnitaryTo(v.mat.Data, 1, bmat.Data, amat.Data)
return return
} }
f64.AxpyIncTo(v.mat.Data, uintptr(v.mat.Inc), 0, f64.AxpyIncTo(v.mat.Data, uintptr(v.mat.Inc), 0,
1, b.mat.Data, a.mat.Data, 1, bmat.Data, amat.Data,
uintptr(ar), uintptr(b.mat.Inc), uintptr(a.mat.Inc), 0, 0) uintptr(ar), uintptr(bmat.Inc), uintptr(amat.Inc), 0, 0)
return
}
}
for i := 0; i < ar; i++ {
v.setVec(i, a.AtVec(i)+b.AtVec(i))
}
} }
// SubVec subtracts the vector b from a, placing the result in the receiver. // SubVec subtracts the vector b from a, placing the result in the receiver.
func (v *VecDense) SubVec(a, b *VecDense) { func (v *VecDense) SubVec(a, b Vector) {
ar := a.Len() ar := a.Len()
br := b.Len() br := b.Len()
@@ -313,28 +368,43 @@ func (v *VecDense) SubVec(a, b *VecDense) {
panic(ErrShape) panic(ErrShape)
} }
if v != a {
v.checkOverlap(a.mat)
}
if v != b {
v.checkOverlap(b.mat)
}
v.reuseAs(ar) v.reuseAs(ar)
if v.mat.Inc == 1 && a.mat.Inc == 1 && b.mat.Inc == 1 { aU, _ := untranspose(a)
bU, _ := untranspose(b)
if arv, ok := aU.(RawVectorer); ok {
if brv, ok := bU.(RawVectorer); ok {
amat := arv.RawVector()
bmat := brv.RawVector()
if v != a {
v.checkOverlap(amat)
}
if v != b {
v.checkOverlap(bmat)
}
if v.mat.Inc == 1 && amat.Inc == 1 && bmat.Inc == 1 {
// Fast path for a common case. // Fast path for a common case.
f64.AxpyUnitaryTo(v.mat.Data, -1, b.mat.Data, a.mat.Data) f64.AxpyUnitaryTo(v.mat.Data, -1, bmat.Data, amat.Data)
return return
} }
f64.AxpyIncTo(v.mat.Data, uintptr(v.mat.Inc), 0, f64.AxpyIncTo(v.mat.Data, uintptr(v.mat.Inc), 0,
-1, b.mat.Data, a.mat.Data, -1, bmat.Data, amat.Data,
uintptr(ar), uintptr(b.mat.Inc), uintptr(a.mat.Inc), 0, 0) uintptr(ar), uintptr(bmat.Inc), uintptr(amat.Inc), 0, 0)
return
}
}
for i := 0; i < ar; i++ {
v.setVec(i, a.AtVec(i)-b.AtVec(i))
}
} }
// MulElemVec performs element-wise multiplication of a and b, placing the result // MulElemVec performs element-wise multiplication of a and b, placing the result
// in the receiver. // in the receiver.
func (v *VecDense) MulElemVec(a, b *VecDense) { func (v *VecDense) MulElemVec(a, b Vector) {
ar := a.Len() ar := a.Len()
br := b.Len() br := b.Len()
@@ -342,24 +412,48 @@ func (v *VecDense) MulElemVec(a, b *VecDense) {
panic(ErrShape) panic(ErrShape)
} }
if v != a {
v.checkOverlap(a.mat)
}
if v != b {
v.checkOverlap(b.mat)
}
v.reuseAs(ar) v.reuseAs(ar)
amat, bmat := a.RawVector(), b.RawVector() aU, _ := untranspose(a)
for i := 0; i < v.n; i++ { bU, _ := untranspose(b)
v.mat.Data[i*v.mat.Inc] = amat.Data[i*amat.Inc] * bmat.Data[i*bmat.Inc]
if arv, ok := aU.(RawVectorer); ok {
if brv, ok := bU.(RawVectorer); ok {
amat := arv.RawVector()
bmat := brv.RawVector()
if v != a {
v.checkOverlap(amat)
}
if v != b {
v.checkOverlap(bmat)
}
if v.mat.Inc == 1 && amat.Inc == 1 && bmat.Inc == 1 {
// Fast path for a common case.
for i, a := range amat.Data {
v.mat.Data[i] = a * bmat.Data[i]
}
return
}
var ia, ib int
for i := 0; i < ar; i++ {
v.setVec(i, amat.Data[ia]*bmat.Data[ib])
ia += amat.Inc
ib += bmat.Inc
}
return
}
}
for i := 0; i < ar; i++ {
v.setVec(i, a.AtVec(i)*b.AtVec(i))
} }
} }
// DivElemVec performs element-wise division of a by b, placing the result // DivElemVec performs element-wise division of a by b, placing the result
// in the receiver. // in the receiver.
func (v *VecDense) DivElemVec(a, b *VecDense) { func (v *VecDense) DivElemVec(a, b Vector) {
ar := a.Len() ar := a.Len()
br := b.Len() br := b.Len()
@@ -367,88 +461,133 @@ func (v *VecDense) DivElemVec(a, b *VecDense) {
panic(ErrShape) panic(ErrShape)
} }
if v != a {
v.checkOverlap(a.mat)
}
if v != b {
v.checkOverlap(b.mat)
}
v.reuseAs(ar) v.reuseAs(ar)
amat, bmat := a.RawVector(), b.RawVector() aU, _ := untranspose(a)
for i := 0; i < v.n; i++ { bU, _ := untranspose(b)
v.mat.Data[i*v.mat.Inc] = amat.Data[i*amat.Inc] / bmat.Data[i*bmat.Inc]
if arv, ok := aU.(RawVectorer); ok {
if brv, ok := bU.(RawVectorer); ok {
amat := arv.RawVector()
bmat := brv.RawVector()
if v != a {
v.checkOverlap(amat)
}
if v != b {
v.checkOverlap(bmat)
}
if v.mat.Inc == 1 && amat.Inc == 1 && bmat.Inc == 1 {
// Fast path for a common case.
for i, a := range amat.Data {
v.setVec(i, a/bmat.Data[i])
}
return
}
var ia, ib int
for i := 0; i < ar; i++ {
v.setVec(i, amat.Data[ia]/bmat.Data[ib])
ia += amat.Inc
ib += bmat.Inc
}
}
}
for i := 0; i < ar; i++ {
v.setVec(i, a.AtVec(i)/b.AtVec(i))
} }
} }
// MulVec computes a * b. The result is stored into the receiver. // MulVec computes a * b. The result is stored into the receiver.
// MulVec panics if the number of columns in a does not equal the number of rows in b. // MulVec panics if the number of columns in a does not equal the number of rows in b
func (v *VecDense) MulVec(a Matrix, b *VecDense) { // or if the number of columns in b does not equal 1.
func (v *VecDense) MulVec(a Matrix, b Vector) {
r, c := a.Dims() r, c := a.Dims()
br := b.Len() br, bc := b.Dims()
if c != br { if c != br || bc != 1 {
panic(ErrShape) panic(ErrShape)
} }
aU, trans := untranspose(a)
var bmat blas64.Vector
fast := true
bU, _ := untranspose(b)
if rv, ok := bU.(RawVectorer); ok {
bmat = rv.RawVector()
if v != b { if v != b {
v.checkOverlap(b.mat) v.checkOverlap(bmat)
}
} else {
fast = false
} }
a, trans := untranspose(a)
ar, ac := a.Dims()
v.reuseAs(r) v.reuseAs(r)
var restore func() var restore func()
if v == a { if v == aU {
v, restore = v.isolatedWorkspace(a.(*VecDense)) v, restore = v.isolatedWorkspace(aU.(*VecDense))
defer restore() defer restore()
} else if v == b { } else if v == b {
v, restore = v.isolatedWorkspace(b) v, restore = v.isolatedWorkspace(b)
defer restore() defer restore()
} }
switch a := a.(type) { // TODO(kortschak): Improve the non-fast paths.
case *VecDense: switch aU := aU.(type) {
if v != a { case Vector:
v.checkOverlap(a.mat) if b.Len() == 1 {
// {n,1} x {1,1}
v.ScaleVec(b.AtVec(0), aU)
return
} }
if a.Len() == 1 { // {1,n} x {n,1}
// {1,1} x {1,n} if fast {
av := a.At(0, 0) if rv, ok := aU.(RawVectorer); ok {
for i := 0; i < b.Len(); i++ { amat := rv.RawVector()
v.mat.Data[i*v.mat.Inc] = av * b.mat.Data[i*b.mat.Inc] if v != aU {
v.checkOverlap(amat)
} }
if amat.Inc == 1 && bmat.Inc == 1 {
// Fast path for a common case.
v.setVec(0, f64.DotUnitary(amat.Data, bmat.Data))
return return
} }
if b.Len() == 1 { v.setVec(0, f64.DotInc(amat.Data, bmat.Data,
// {1,n} x {1,1} uintptr(c), uintptr(amat.Inc), uintptr(bmat.Inc), 0, 0))
bv := b.At(0, 0)
for i := 0; i < a.Len(); i++ {
v.mat.Data[i*v.mat.Inc] = bv * a.mat.Data[i*a.mat.Inc]
}
return return
} }
// {n,1} x {1,n} }
var sum float64 var sum float64
for i := 0; i < c; i++ { for i := 0; i < c; i++ {
sum += a.At(i, 0) * b.At(i, 0) sum += aU.AtVec(i) * b.AtVec(i)
} }
v.SetVec(0, sum) v.setVec(0, sum)
return return
case RawSymmetricer: case RawSymmetricer:
amat := a.RawSymmetric() if fast {
blas64.Symv(1, amat, b.mat, 0, v.mat) amat := aU.RawSymmetric()
// We don't know that a is a *SymDense, so make
// a temporary SymDense to check overlap.
(&SymDense{mat: amat}).checkOverlap(v.asGeneral())
blas64.Symv(1, amat, bmat, 0, v.mat)
return
}
case RawTriangular: case RawTriangular:
v.CopyVec(b) v.CopyVec(b)
amat := a.RawTriangular() amat := aU.RawTriangular()
// We don't know that a is a *TriDense, so make
// a temporary TriDense to check overlap.
(&TriDense{mat: amat}).checkOverlap(v.asGeneral())
ta := blas.NoTrans ta := blas.NoTrans
if trans { if trans {
ta = blas.Trans ta = blas.Trans
} }
blas64.Trmv(ta, amat, v.mat) blas64.Trmv(ta, amat, v.mat)
case RawMatrixer: case RawMatrixer:
amat := a.RawMatrix() if fast {
amat := aU.RawMatrix()
// We don't know that a is a *Dense, so make // We don't know that a is a *Dense, so make
// a temporary Dense to check overlap. // a temporary Dense to check overlap.
(&Dense{mat: amat}).checkOverlap(v.asGeneral()) (&Dense{mat: amat}).checkOverlap(v.asGeneral())
@@ -456,33 +595,28 @@ func (v *VecDense) MulVec(a Matrix, b *VecDense) {
if trans { if trans {
t = blas.Trans t = blas.Trans
} }
blas64.Gemv(t, 1, amat, b.mat, 0, v.mat) blas64.Gemv(t, 1, amat, bmat, 0, v.mat)
return
}
default: default:
if trans { if fast {
col := make([]float64, ar) for i := 0; i < r; i++ {
for c := 0; c < ac; c++ {
for i := range col {
col[i] = a.At(i, c)
}
var f float64 var f float64
for i, e := range col { for j := 0; j < c; j++ {
f += e * b.mat.Data[i*b.mat.Inc] f += a.At(i, j) * bmat.Data[j*bmat.Inc]
} }
v.mat.Data[c*v.mat.Inc] = f v.setVec(i, f)
} }
} else { return
row := make([]float64, ac)
for r := 0; r < ar; r++ {
for i := range row {
row[i] = a.At(r, i)
} }
}
for i := 0; i < r; i++ {
var f float64 var f float64
for i, e := range row { for j := 0; j < c; j++ {
f += e * b.mat.Data[i*b.mat.Inc] f += a.At(i, j) * b.AtVec(j)
}
v.mat.Data[r*v.mat.Inc] = f
}
} }
v.setVec(i, f)
} }
} }
@@ -510,7 +644,7 @@ func (v *VecDense) IsZero() bool {
return v.mat.Inc == 0 return v.mat.Inc == 0
} }
func (v *VecDense) isolatedWorkspace(a *VecDense) (n *VecDense, restore func()) { func (v *VecDense) isolatedWorkspace(a Vector) (n *VecDense, restore func()) {
l := a.Len() l := a.Len()
n = getWorkspaceVec(l, false) n = getWorkspaceVec(l, false)
return n, func() { return n, func() {

View File

@@ -184,10 +184,10 @@ func TestVecDenseAtSet(t *testing.T) {
func TestVecDenseMul(t *testing.T) { func TestVecDenseMul(t *testing.T) {
method := func(receiver, a, b Matrix) { method := func(receiver, a, b Matrix) {
type mulVecer interface { type mulVecer interface {
MulVec(a Matrix, b *VecDense) MulVec(a Matrix, b Vector)
} }
rd := receiver.(mulVecer) rd := receiver.(mulVecer)
rd.MulVec(a, b.(*VecDense)) rd.MulVec(a, b.(Vector))
} }
denseComparison := func(receiver, a, b *Dense) { denseComparison := func(receiver, a, b *Dense) {
receiver.Mul(a, b) receiver.Mul(a, b)
@@ -266,10 +266,10 @@ func TestVecDenseScale(t *testing.T) {
for _, alpha := range []float64{0, 1, -1, 2.3, -2.3} { for _, alpha := range []float64{0, 1, -1, 2.3, -2.3} {
method := func(receiver, a Matrix) { method := func(receiver, a Matrix) {
type scaleVecer interface { type scaleVecer interface {
ScaleVec(float64, *VecDense) ScaleVec(float64, Vector)
} }
v := receiver.(scaleVecer) v := receiver.(scaleVecer)
v.ScaleVec(alpha, a.(*VecDense)) v.ScaleVec(alpha, a.(Vector))
} }
denseComparison := func(receiver, a *Dense) { denseComparison := func(receiver, a *Dense) {
receiver.Scale(alpha, a) receiver.Scale(alpha, a)
@@ -282,10 +282,10 @@ func TestVecDenseAddScaled(t *testing.T) {
for _, alpha := range []float64{0, 1, -1, 2.3, -2.3} { for _, alpha := range []float64{0, 1, -1, 2.3, -2.3} {
method := func(receiver, a, b Matrix) { method := func(receiver, a, b Matrix) {
type addScaledVecer interface { type addScaledVecer interface {
AddScaledVec(*VecDense, float64, *VecDense) AddScaledVec(Vector, float64, Vector)
} }
v := receiver.(addScaledVecer) v := receiver.(addScaledVecer)
v.AddScaledVec(a.(*VecDense), alpha, b.(*VecDense)) v.AddScaledVec(a.(Vector), alpha, b.(Vector))
} }
denseComparison := func(receiver, a, b *Dense) { denseComparison := func(receiver, a, b *Dense) {
var sb Dense var sb Dense