mirror of
https://github.com/gonum/gonum.git
synced 2025-10-05 15:16:59 +08:00
232 lines
4.5 KiB
Go
232 lines
4.5 KiB
Go
// Copyright ©2015 The Gonum Authors. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
package mat
|
|
|
|
import (
|
|
"testing"
|
|
|
|
"golang.org/x/exp/rand"
|
|
|
|
"gonum.org/v1/gonum/blas"
|
|
"gonum.org/v1/gonum/blas/blas64"
|
|
"gonum.org/v1/gonum/floats"
|
|
)
|
|
|
|
// TODO: Add tests where one of the operands is overwritten.
|
|
func TestMulTypes(t *testing.T) {
|
|
t.Parallel()
|
|
src := rand.NewSource(1)
|
|
for _, test := range []struct {
|
|
ar int
|
|
ac int
|
|
br int
|
|
bc int
|
|
Panics bool
|
|
}{
|
|
{
|
|
ar: 5,
|
|
ac: 5,
|
|
br: 5,
|
|
bc: 5,
|
|
Panics: false,
|
|
},
|
|
{
|
|
ar: 10,
|
|
ac: 5,
|
|
br: 5,
|
|
bc: 3,
|
|
Panics: false,
|
|
},
|
|
{
|
|
ar: 10,
|
|
ac: 5,
|
|
br: 5,
|
|
bc: 8,
|
|
Panics: false,
|
|
},
|
|
{
|
|
ar: 8,
|
|
ac: 10,
|
|
br: 10,
|
|
bc: 3,
|
|
Panics: false,
|
|
},
|
|
{
|
|
ar: 8,
|
|
ac: 3,
|
|
br: 3,
|
|
bc: 10,
|
|
Panics: false,
|
|
},
|
|
{
|
|
ar: 5,
|
|
ac: 8,
|
|
br: 8,
|
|
bc: 10,
|
|
Panics: false,
|
|
},
|
|
{
|
|
ar: 5,
|
|
ac: 12,
|
|
br: 12,
|
|
bc: 8,
|
|
Panics: false,
|
|
},
|
|
{
|
|
ar: 5,
|
|
ac: 7,
|
|
br: 8,
|
|
bc: 10,
|
|
Panics: true,
|
|
},
|
|
} {
|
|
ar := test.ar
|
|
ac := test.ac
|
|
br := test.br
|
|
bc := test.bc
|
|
|
|
// Generate random matrices
|
|
avec := make([]float64, ar*ac)
|
|
randomSlice(avec, src)
|
|
a := NewDense(ar, ac, avec)
|
|
|
|
bvec := make([]float64, br*bc)
|
|
randomSlice(bvec, src)
|
|
|
|
b := NewDense(br, bc, bvec)
|
|
|
|
// Check that it panics if it is supposed to
|
|
if test.Panics {
|
|
c := &Dense{}
|
|
fn := func() {
|
|
c.Mul(a, b)
|
|
}
|
|
pan, _ := panics(fn)
|
|
if !pan {
|
|
t.Errorf("Mul did not panic with dimension mismatch")
|
|
}
|
|
continue
|
|
}
|
|
|
|
cvec := make([]float64, ar*bc)
|
|
|
|
// Get correct matrix multiply answer from blas64.Gemm
|
|
blas64.Gemm(blas.NoTrans, blas.NoTrans,
|
|
1, a.mat, b.mat,
|
|
0, blas64.General{Rows: ar, Cols: bc, Stride: bc, Data: cvec},
|
|
)
|
|
|
|
avecCopy := append([]float64{}, avec...)
|
|
bvecCopy := append([]float64{}, bvec...)
|
|
cvecCopy := append([]float64{}, cvec...)
|
|
|
|
acomp := matComp{r: ar, c: ac, data: avecCopy}
|
|
bcomp := matComp{r: br, c: bc, data: bvecCopy}
|
|
ccomp := matComp{r: ar, c: bc, data: cvecCopy}
|
|
|
|
// Do normal multiply with empty dense
|
|
d := &Dense{}
|
|
|
|
testMul(t, a, b, d, acomp, bcomp, ccomp, false, "empty receiver")
|
|
|
|
// Normal multiply with existing receiver
|
|
c := NewDense(ar, bc, cvec)
|
|
randomSlice(cvec, src)
|
|
testMul(t, a, b, c, acomp, bcomp, ccomp, false, "existing receiver")
|
|
|
|
// Cast a as a basic matrix
|
|
am := (*basicMatrix)(a)
|
|
bm := (*basicMatrix)(b)
|
|
d.Reset()
|
|
testMul(t, am, b, d, acomp, bcomp, ccomp, true, "a is basic, receiver is empty")
|
|
d.Reset()
|
|
testMul(t, a, bm, d, acomp, bcomp, ccomp, true, "b is basic, receiver is empty")
|
|
d.Reset()
|
|
testMul(t, am, bm, d, acomp, bcomp, ccomp, true, "both basic, receiver is empty")
|
|
randomSlice(cvec, src)
|
|
testMul(t, am, b, d, acomp, bcomp, ccomp, true, "a is basic, receiver is full")
|
|
randomSlice(cvec, src)
|
|
testMul(t, a, bm, d, acomp, bcomp, ccomp, true, "b is basic, receiver is full")
|
|
randomSlice(cvec, src)
|
|
testMul(t, am, bm, d, acomp, bcomp, ccomp, true, "both basic, receiver is full")
|
|
}
|
|
}
|
|
|
|
func randomSlice(s []float64, src rand.Source) {
|
|
rnd := rand.New(src)
|
|
for i := range s {
|
|
s[i] = rnd.NormFloat64()
|
|
}
|
|
}
|
|
|
|
// matComp is a snapshot of a matrix's expected dimensions and element data,
// used as the reference value when checking Mul's operands and result.
type matComp struct {
	r    int       // expected number of rows
	c    int       // expected number of columns
	data []float64 // expected backing data, compared against Dense.mat.Data
}
|
|
|
|
func testMul(t *testing.T, a, b Matrix, c *Dense, acomp, bcomp, ccomp matComp, cvecApprox bool, name string) {
|
|
c.Mul(a, b)
|
|
var aDense *Dense
|
|
switch t := a.(type) {
|
|
case *Dense:
|
|
aDense = t
|
|
case *basicMatrix:
|
|
aDense = (*Dense)(t)
|
|
}
|
|
|
|
var bDense *Dense
|
|
switch t := b.(type) {
|
|
case *Dense:
|
|
bDense = t
|
|
case *basicMatrix:
|
|
bDense = (*Dense)(t)
|
|
}
|
|
|
|
if !denseEqual(aDense, acomp) {
|
|
t.Errorf("a changed unexpectedly for %v", name)
|
|
}
|
|
if !denseEqual(bDense, bcomp) {
|
|
t.Errorf("b changed unexpectedly for %v", name)
|
|
}
|
|
if cvecApprox {
|
|
if !denseEqualApprox(c, ccomp, 1e-14) {
|
|
t.Errorf("mul answer not within tol for %v", name)
|
|
}
|
|
return
|
|
}
|
|
|
|
if !denseEqual(c, ccomp) {
|
|
t.Errorf("mul answer not equal for %v", name)
|
|
}
|
|
}
|
|
|
|
func denseEqual(a *Dense, acomp matComp) bool {
|
|
ar2, ac2 := a.Dims()
|
|
if ar2 != acomp.r {
|
|
return false
|
|
}
|
|
if ac2 != acomp.c {
|
|
return false
|
|
}
|
|
if !floats.Equal(a.mat.Data, acomp.data) {
|
|
return false
|
|
}
|
|
return true
|
|
}
|
|
|
|
func denseEqualApprox(a *Dense, acomp matComp, tol float64) bool {
|
|
ar2, ac2 := a.Dims()
|
|
if ar2 != acomp.r {
|
|
return false
|
|
}
|
|
if ac2 != acomp.c {
|
|
return false
|
|
}
|
|
if !floats.EqualApprox(a.mat.Data, acomp.data, tol) {
|
|
return false
|
|
}
|
|
return true
|
|
}
|