mirror of
				https://github.com/gonum/gonum.git
				synced 2025-11-01 02:52:49 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			249 lines
		
	
	
		
			4.6 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			249 lines
		
	
	
		
			4.6 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| // Copyright ©2014 The Gonum Authors. All rights reserved.
 | |
| // Use of this source code is governed by a BSD-style
 | |
| // license that can be found in the LICENSE file.
 | |
| 
 | |
| package testblas
 | |
| 
 | |
| import (
 | |
| 	"testing"
 | |
| 
 | |
| 	"gonum.org/v1/gonum/blas"
 | |
| )
 | |
| 
 | |
| type Dgemmer interface {
 | |
| 	Dgemm(tA, tB blas.Transpose, m, n, k int, alpha float64, a []float64, lda int, b []float64, ldb int, beta float64, c []float64, ldc int)
 | |
| }
 | |
| 
 | |
// DgemmCase is one reference problem for Dgemm. A is m×k, B is k×n and C is
// m×n, each stored untransposed as a slice of rows; ans is the expected value
// of alpha*A*B + beta*C.
type DgemmCase struct {
	// Problem dimensions.
	m int
	n int
	k int

	// Scale factors for the product and for the initial C.
	alpha float64
	beta  float64

	// Operands and expected result, row by row.
	a   [][]float64
	b   [][]float64
	c   [][]float64
	ans [][]float64
}
 | |
| 
 | |
// DgemmCases holds the reference problems driven by TestDgemm. Every case uses
// alpha = 2 and beta = 0.5, and each ans entry equals alpha*A*B + beta*C
// computed from the untransposed a, b and c given here. The six cases cover
// every ordering of the dimensions m, n and k among two distinct sizes.
var DgemmCases = []DgemmCase{

	// m=4, n=3, k=2: A is 4×2, B is 2×3.
	{
		m:     4,
		n:     3,
		k:     2,
		alpha: 2,
		beta:  0.5,
		a: [][]float64{
			{1, 2},
			{4, 5},
			{7, 8},
			{10, 11},
		},
		b: [][]float64{
			{1, 5, 6},
			{5, -8, 8},
		},
		c: [][]float64{
			{4, 8, -9},
			{12, 16, -8},
			{1, 5, 15},
			{-3, -4, 7},
		},
		ans: [][]float64{
			{24, -18, 39.5},
			{64, -32, 124},
			{94.5, -55.5, 219.5},
			{128.5, -78, 299.5},
		},
	},
	// m=4, n=2, k=3: A is 4×3, B is 3×2.
	{
		m:     4,
		n:     2,
		k:     3,
		alpha: 2,
		beta:  0.5,
		a: [][]float64{
			{1, 2, 3},
			{4, 5, 6},
			{7, 8, 9},
			{10, 11, 12},
		},
		b: [][]float64{
			{1, 5},
			{5, -8},
			{6, 2},
		},
		c: [][]float64{
			{4, 8},
			{12, 16},
			{1, 5},
			{-3, -4},
		},
		ans: [][]float64{
			{60, -6},
			{136, -8},
			{202.5, -19.5},
			{272.5, -30},
		},
	},
	// m=3, n=2, k=4: A is 3×4, B is 4×2.
	{
		m:     3,
		n:     2,
		k:     4,
		alpha: 2,
		beta:  0.5,
		a: [][]float64{
			{1, 2, 3, 4},
			{4, 5, 6, 7},
			{8, 9, 10, 11},
		},
		b: [][]float64{
			{1, 5},
			{5, -8},
			{6, 2},
			{8, 10},
		},
		c: [][]float64{
			{4, 8},
			{12, 16},
			{9, -10},
		},
		ans: [][]float64{
			{124, 74},
			{248, 132},
			{406.5, 191},
		},
	},
	// m=3, n=4, k=2: A is 3×2, B is 2×4.
	{
		m:     3,
		n:     4,
		k:     2,
		alpha: 2,
		beta:  0.5,
		a: [][]float64{
			{1, 2},
			{4, 5},
			{8, 9},
		},
		b: [][]float64{
			{1, 5, 2, 1},
			{5, -8, 2, 1},
		},
		c: [][]float64{
			{4, 8, 2, 2},
			{12, 16, 8, 9},
			{9, -10, 10, 10},
		},
		ans: [][]float64{
			{24, -18, 13, 7},
			{64, -32, 40, 22.5},
			{110.5, -69, 73, 39},
		},
	},
	// m=2, n=4, k=3: A is 2×3, B is 3×4.
	{
		m:     2,
		n:     4,
		k:     3,
		alpha: 2,
		beta:  0.5,
		a: [][]float64{
			{1, 2, 3},
			{4, 5, 6},
		},
		b: [][]float64{
			{1, 5, 8, 8},
			{5, -8, 9, 10},
			{6, 2, -3, 2},
		},
		c: [][]float64{
			{4, 8, 7, 8},
			{12, 16, -2, 6},
		},
		ans: [][]float64{
			{60, -6, 37.5, 72},
			{136, -8, 117, 191},
		},
	},
	// m=2, n=3, k=4: A is 2×4, B is 4×3.
	{
		m:     2,
		n:     3,
		k:     4,
		alpha: 2,
		beta:  0.5,
		a: [][]float64{
			{1, 2, 3, 4},
			{4, 5, 6, 7},
		},
		b: [][]float64{
			{1, 5, 8},
			{5, -8, 9},
			{6, 2, -3},
			{8, 10, 2},
		},
		c: [][]float64{
			{4, 8, 1},
			{12, 16, 6},
		},
		ans: [][]float64{
			{124, 74, 50.5},
			{248, 132, 149},
		},
	},
}
 | |
| 
 | |
// transpose returns a newly allocated transpose of a, where a is a matrix
// stored as a slice of equal-length rows. The result has len(a[0]) rows of
// len(a) elements each, and a itself is not modified.
//
// A matrix with no rows yields a nil (empty) result rather than panicking.
// If the rows of a are ragged, the length of the first row determines the
// number of output rows. A later row shorter than the first panics with an
// index out of range.
func transpose(a [][]float64) [][]float64 {
	if len(a) == 0 {
		// Nothing to transpose; indexing a[0] below would panic.
		return nil
	}
	rows, cols := len(a[0]), len(a)
	b := make([][]float64, rows)
	for i := range b {
		row := make([]float64, cols)
		for j := range row {
			row[j] = a[j][i]
		}
		b[i] = row
	}
	return b
}
 | |
| 
 | |
| func TestDgemm(t *testing.T, blasser Dgemmer) {
 | |
| 	for i, test := range DgemmCases {
 | |
| 		// Test that it passes row major
 | |
| 		dgemmcomp(i, "RowMajorNoTrans", t, blasser, blas.NoTrans, blas.NoTrans,
 | |
| 			test.m, test.n, test.k, test.alpha, test.beta, test.a, test.b, test.c, test.ans)
 | |
| 		// Try with A transposed
 | |
| 		dgemmcomp(i, "RowMajorTransA", t, blasser, blas.Trans, blas.NoTrans,
 | |
| 			test.m, test.n, test.k, test.alpha, test.beta, transpose(test.a), test.b, test.c, test.ans)
 | |
| 		// Try with B transposed
 | |
| 		dgemmcomp(i, "RowMajorTransB", t, blasser, blas.NoTrans, blas.Trans,
 | |
| 			test.m, test.n, test.k, test.alpha, test.beta, test.a, transpose(test.b), test.c, test.ans)
 | |
| 		// Try with both transposed
 | |
| 		dgemmcomp(i, "RowMajorTransBoth", t, blasser, blas.Trans, blas.Trans,
 | |
| 			test.m, test.n, test.k, test.alpha, test.beta, transpose(test.a), transpose(test.b), test.c, test.ans)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func dgemmcomp(i int, name string, t *testing.T, blasser Dgemmer, tA, tB blas.Transpose, m, n, k int,
 | |
| 	alpha, beta float64, a [][]float64, b [][]float64, c [][]float64, ans [][]float64) {
 | |
| 
 | |
| 	aFlat := flatten(a)
 | |
| 	aCopy := flatten(a)
 | |
| 	bFlat := flatten(b)
 | |
| 	bCopy := flatten(b)
 | |
| 	cFlat := flatten(c)
 | |
| 	ansFlat := flatten(ans)
 | |
| 	lda := len(a[0])
 | |
| 	ldb := len(b[0])
 | |
| 	ldc := len(c[0])
 | |
| 
 | |
| 	// Compute the matrix multiplication
 | |
| 	blasser.Dgemm(tA, tB, m, n, k, alpha, aFlat, lda, bFlat, ldb, beta, cFlat, ldc)
 | |
| 
 | |
| 	if !dSliceEqual(aFlat, aCopy) {
 | |
| 		t.Errorf("Test %v case %v: a changed during call to Dgemm", i, name)
 | |
| 	}
 | |
| 	if !dSliceEqual(bFlat, bCopy) {
 | |
| 		t.Errorf("Test %v case %v: b changed during call to Dgemm", i, name)
 | |
| 	}
 | |
| 
 | |
| 	if !dSliceTolEqual(ansFlat, cFlat) {
 | |
| 		t.Errorf("Test %v case %v: answer mismatch. Expected %v, Found %v", i, name, ansFlat, cFlat)
 | |
| 	}
 | |
| 	// TODO: Need to add a sub-slice test where don't use up full matrix
 | |
| }
 | 
