mirror of
https://github.com/gonum/gonum.git
synced 2025-10-05 23:26:52 +08:00
259 lines
5.6 KiB
Go
259 lines
5.6 KiB
Go
// Copyright ©2015 The Gonum Authors. All rights reserved.
|
||
// Use of this source code is governed by a BSD-style
|
||
// license that can be found in the LICENSE file.
|
||
|
||
package mat
|
||
|
||
import (
|
||
"fmt"
|
||
"testing"
|
||
|
||
"golang.org/x/exp/rand"
|
||
)
|
||
|
||
// dims holds the row and column counts of a matrix in a product chain.
type dims struct {
	r int // number of rows
	c int // number of columns
}
|
||
|
||
var productTests = []struct {
|
||
n int
|
||
factors []dims
|
||
product dims
|
||
panics bool
|
||
}{
|
||
{
|
||
n: 1,
|
||
factors: []dims{{3, 4}},
|
||
product: dims{3, 4},
|
||
panics: false,
|
||
},
|
||
{
|
||
n: 1,
|
||
factors: []dims{{2, 4}},
|
||
product: dims{3, 4},
|
||
panics: true,
|
||
},
|
||
{
|
||
n: 3,
|
||
factors: []dims{{10, 30}, {30, 5}, {5, 60}},
|
||
product: dims{10, 60},
|
||
panics: false,
|
||
},
|
||
{
|
||
n: 3,
|
||
factors: []dims{{100, 30}, {30, 5}, {5, 60}},
|
||
product: dims{10, 60},
|
||
panics: true,
|
||
},
|
||
{
|
||
n: 7,
|
||
factors: []dims{{60, 5}, {5, 5}, {5, 4}, {4, 10}, {10, 22}, {22, 45}, {45, 10}},
|
||
product: dims{60, 10},
|
||
panics: false,
|
||
},
|
||
{
|
||
n: 7,
|
||
factors: []dims{{60, 5}, {5, 5}, {5, 400}, {4, 10}, {10, 22}, {22, 45}, {45, 10}},
|
||
product: dims{60, 10},
|
||
panics: true,
|
||
},
|
||
{
|
||
n: 3,
|
||
factors: []dims{{1, 1000}, {1000, 2}, {2, 2}},
|
||
product: dims{1, 2},
|
||
panics: false,
|
||
},
|
||
|
||
// Random chains.
|
||
{
|
||
n: 0,
|
||
product: dims{0, 0},
|
||
panics: false,
|
||
},
|
||
{
|
||
n: 2,
|
||
product: dims{60, 10},
|
||
panics: false,
|
||
},
|
||
{
|
||
n: 3,
|
||
product: dims{60, 10},
|
||
panics: false,
|
||
},
|
||
{
|
||
n: 4,
|
||
product: dims{60, 10},
|
||
panics: false,
|
||
},
|
||
{
|
||
n: 10,
|
||
product: dims{60, 10},
|
||
panics: false,
|
||
},
|
||
}
|
||
|
||
func TestProduct(t *testing.T) {
|
||
t.Parallel()
|
||
rnd := rand.New(rand.NewSource(1))
|
||
for _, test := range productTests {
|
||
dimensions := test.factors
|
||
if dimensions == nil && test.n > 0 {
|
||
dimensions = make([]dims, test.n)
|
||
for i := range dimensions {
|
||
if i != 0 {
|
||
dimensions[i].r = dimensions[i-1].c
|
||
}
|
||
dimensions[i].c = rnd.Intn(50) + 1
|
||
}
|
||
dimensions[0].r = test.product.r
|
||
dimensions[test.n-1].c = test.product.c
|
||
}
|
||
factors := make([]Matrix, test.n)
|
||
for i, d := range dimensions {
|
||
data := make([]float64, d.r*d.c)
|
||
for i := range data {
|
||
data[i] = rnd.Float64()
|
||
}
|
||
factors[i] = NewDense(d.r, d.c, data)
|
||
}
|
||
|
||
want := &Dense{}
|
||
if !test.panics {
|
||
var a *Dense
|
||
for i, b := range factors {
|
||
if i == 0 {
|
||
want.CloneFrom(b)
|
||
continue
|
||
}
|
||
a, want = want, &Dense{}
|
||
want.Mul(a, b)
|
||
}
|
||
}
|
||
|
||
got := &Dense{}
|
||
if test.product.r != 0 && test.product.c != 0 {
|
||
got = NewDense(test.product.r, test.product.c, nil)
|
||
}
|
||
panicked, message := panics(func() {
|
||
got.Product(factors...)
|
||
})
|
||
if test.panics {
|
||
if !panicked {
|
||
t.Errorf("fail to panic with product chain dimensions: %+v result dimension: %+v",
|
||
dimensions, test.product)
|
||
}
|
||
continue
|
||
} else if panicked {
|
||
t.Errorf("unexpected panic %q with product chain dimensions: %+v result dimension: %+v",
|
||
message, dimensions, test.product)
|
||
continue
|
||
}
|
||
|
||
if len(factors) > 0 {
|
||
p := newMultiplier(NewDense(test.product.r, test.product.c, nil), factors)
|
||
p.optimize()
|
||
gotCost := p.table.at(0, len(factors)-1).cost
|
||
expr, wantCost, ok := bestExpressionFor(dimensions)
|
||
if !ok {
|
||
t.Fatal("unexpected number of expressions in brute force expression search")
|
||
}
|
||
if gotCost != wantCost {
|
||
t.Errorf("unexpected cost for chain dimensions: %+v got: %v want: %v\n%s",
|
||
dimensions, got, want, expr)
|
||
}
|
||
}
|
||
|
||
if !EqualApprox(got, want, 1e-14) {
|
||
t.Errorf("unexpected result from product chain dimensions: %+v", dimensions)
|
||
}
|
||
}
|
||
}
|
||
|
||
// node is a subexpression node.
|
||
type node struct {
|
||
dims
|
||
left, right *node
|
||
}
|
||
|
||
func (n *node) String() string {
|
||
if n.left == nil || n.right == nil {
|
||
rows, cols := n.shape()
|
||
return fmt.Sprintf("[%d×%d]", rows, cols)
|
||
}
|
||
rows, cols := n.shape()
|
||
return fmt.Sprintf("(%s * %s):[%d×%d]", n.left, n.right, rows, cols)
|
||
}
|
||
|
||
// shape returns the dimensions of the result of the subexpression.
|
||
func (n *node) shape() (rows, cols int) {
|
||
if n.left == nil || n.right == nil {
|
||
return n.r, n.c
|
||
}
|
||
rows, _ = n.left.shape()
|
||
_, cols = n.right.shape()
|
||
return rows, cols
|
||
}
|
||
|
||
// cost returns the cost to evaluate the subexpression.
|
||
func (n *node) cost() int {
|
||
if n.left == nil || n.right == nil {
|
||
return 0
|
||
}
|
||
lr, lc := n.left.shape()
|
||
_, rc := n.right.shape()
|
||
return lr*lc*rc + n.left.cost() + n.right.cost()
|
||
}
|
||
|
||
// expressionsFor returns a channel that can be used to iterate over all
|
||
// expressions of the given factor dimensions.
|
||
func expressionsFor(factors []dims) chan *node {
|
||
if len(factors) == 1 {
|
||
c := make(chan *node, 1)
|
||
c <- &node{dims: factors[0]}
|
||
close(c)
|
||
return c
|
||
}
|
||
c := make(chan *node)
|
||
go func() {
|
||
for i := 1; i < len(factors); i++ {
|
||
for left := range expressionsFor(factors[:i]) {
|
||
for right := range expressionsFor(factors[i:]) {
|
||
c <- &node{left: left, right: right}
|
||
}
|
||
}
|
||
}
|
||
close(c)
|
||
}()
|
||
return c
|
||
}
|
||
|
||
// catalan returns the nth 0-based Catalan number: C(0) = 1, C(1) = 1,
// C(2) = 2, C(3) = 5, C(4) = 14, and so on. It returns 1 for n <= 0.
//
// It uses the recurrence C(k+1) = C(k)·2(2k+1)/(k+2). Every division in that
// recurrence is exact, because C(k)·2(2k+1) = C(k+1)·(k+2). This keeps the
// intermediate values close to the result. Forming (n+1)·…·(2n) first would
// overflow int64 from n = 15, while the recurrence stays exact past n = 30.
func catalan(n int) int {
	// Work in 64-bit integers so the result does not depend on the size of int.
	c := int64(1)
	for k := 0; k < n; k++ {
		c = c * int64(2*(2*k+1)) / int64(k+2)
	}
	return int(c)
}
|
||
|
||
// bestExpressonFor returns the lowest cost expression for the given expression
|
||
// factor dimensions, the cost of the expression and whether the number of
|
||
// expressions searched matches the Catalan number for the number of factors.
|
||
func bestExpressionFor(factors []dims) (exp *node, cost int, ok bool) {
|
||
const maxInt = int(^uint(0) >> 1)
|
||
min := maxInt
|
||
var best *node
|
||
var n int
|
||
for exp := range expressionsFor(factors) {
|
||
n++
|
||
cost := exp.cost()
|
||
if cost < min {
|
||
min = cost
|
||
best = exp
|
||
}
|
||
}
|
||
return best, min, n == catalan(len(factors)-1)
|
||
}
|