diff --git a/referenceblas/level2double.go b/referenceblas/level2double.go index 6e07fda7..f181b624 100644 --- a/referenceblas/level2double.go +++ b/referenceblas/level2double.go @@ -1,8 +1,11 @@ package referenceblas -import ( - "github.com/gonum/blas" -) +import "github.com/gonum/blas" + +// See http://www.netlib.org/lapack/explore-html/d4/de1/_l_i_c_e_n_s_e_source.html +// for more license information + +// TODO: Need to think about loops when doing row-major. Change after tests? const ( badOrder string = "referenceblas: illegal order" @@ -11,30 +14,9 @@ const ( badUplo string = "referenceblas: illegal triangularization" badTranspose string = "referenceblas: illegal transpose" badDiag string = "referenceblas: illegal diag" + badLda string = "lda must be lass than max(1,n)" ) -func getLevel2Indexes(o blas.Order, tA blas.Transpose, m, n, incX, incY int) (lenx, leny, kx, ky int) { - // Set up the lengths of the vectors and start up points - // TODO: Figure out how this works with order - lenx = m - leny = n - if tA == blas.NoTrans { - lenx = n - leny = m - } - if incX > 0 { - kx = 0 - } else { - kx = -(lenx - 1) * incX - } - if incY > 0 { - ky = 0 - } else { - ky = -(leny - 1) * incY - } - return lenx, leny, kx, ky -} - // Dgemv computes y = alpha*a*x + beta*y if tA = blas.NoTrans // or alpha*A^T*x + beta*y if tA = blas.Trans or blas.ConjTrans func (b Blas) Dgemv(o blas.Order, tA blas.Transpose, m, n int, alpha float64, a []float64, lda int, x []float64, incX int, beta float64, y []float64, incY int) { @@ -71,7 +53,24 @@ func (b Blas) Dgemv(o blas.Order, tA blas.Transpose, m, n int, alpha float64, a return } - _, lenY, kx, ky := getLevel2Indexes(o, tA, m, n, incX, incY) + // Set up indexes + lenX := m + lenY := n + if tA == blas.NoTrans { + lenX = n + lenY = m + } + var kx, ky int + if incX > 0 { + kx = 0 + } else { + kx = -(lenX - 1) * incX + } + if incY > 0 { + ky = 0 + } else { + ky = -(lenY - 1) * incY + } // First form y := beta * y b.Dscal(lenY, beta, y, incY) @@ -82,32 +81,57 @@ func (b Blas) Dgemv(o blas.Order, tA blas.Transpose, m, n int, alpha float64, a // Form y := alpha * A * x + y switch { - case tA == blas.NoTrans && o == blas.RowMajor: + + default: + panic("shouldn't be here") + + case o == blas.RowMajor && tA == blas.NoTrans: + iy := ky + for i := 0; i < m; i++ { + jx := kx + var temp float64 + for j := 0; j < n; j++ { + temp += a[lda*i+j] * x[jx] + jx += incX + } + y[iy] += alpha * temp + iy += incY + } + case o == blas.RowMajor && (tA == blas.Trans || tA == blas.ConjTrans): + ix := kx + for i := 0; i < m; i++ { + jy := ky + tmp := alpha * x[ix] + for j := 0; j < n; j++ { + y[jy] += a[lda*i+j] * tmp + jy += incY + } + ix += incX + } + + case o == blas.ColMajor && tA == blas.NoTrans: jx := kx for j := 0; j < n; j++ { temp := alpha * x[jx] iy := ky for i := 0; i < m; i++ { - y[iy] += temp * a[lda*i+j] + y[iy] += temp * a[lda*j+i] iy += incY } jx += incX } - case (tA == blas.Trans || tA == blas.ConjTrans) && o == blas.RowMajor: + case o == blas.ColMajor && (tA == blas.Trans || tA == blas.ConjTrans): jy := ky for j := 0; j < n; j++ { var temp float64 ix := kx for i := 0; i < m; i++ { - temp += a[lda*i+j] * x[ix] + temp += a[lda*j+i] * x[ix] ix += incX } y[jy] += alpha * temp jy += incY } - default: - // TODO: Add in other switch cases - panic("Not yet implemented") } } @@ -115,7 +139,7 @@ func (b Blas) Dgemv(o blas.Order, tA blas.Transpose, m, n int, alpha float64, a // x := A*x, or x := A**T*x, // where x is an n element vector and A is an n by n unit, or non-unit, // upper or lower triangular matrix. -func Dtrmv(o blas.Order, ul blas.Uplo, tA blas.Transpose, d blas.Diag, n int, a []float64, lda int, x []float64, incX int) { +func (Blas) Dtrmv(o blas.Order, ul blas.Uplo, tA blas.Transpose, d blas.Diag, n int, a []float64, lda int, x []float64, incX int) { // Verify inputs if o != blas.RowMajor && o != blas.ColMajor { panic(badOrder) @@ -146,6 +170,8 @@ func Dtrmv(o blas.Order, ul blas.Uplo, tA blas.Transpose, d blas.Diag, n int, a kx = -(n - 1) * incX } switch { + default: + panic("not yet implemented") case o == blas.RowMajor && tA == blas.NoTrans && ul == blas.Upper: jx := kx for j := 0; j < n; j++ { @@ -207,8 +233,6 @@ func Dtrmv(o blas.Order, ul blas.Uplo, tA blas.Transpose, d blas.Diag, n int, a jx += incX } } - default: - panic("not yet implemented") } } @@ -219,7 +243,7 @@ func Dtrmv(o blas.Order, ul blas.Uplo, tA blas.Transpose, d blas.Diag, n int, a // // No test for singularity or near-singularity is included in this // routine. Such tests must be performed before calling this routine. -func Dtrsv(o blas.Order, ul blas.Uplo, tA blas.Transpose, d blas.Diag, n int, a []float64, lda int, x []float64, incX int) { +func (Blas) Dtrsv(o blas.Order, ul blas.Uplo, tA blas.Transpose, d blas.Diag, n int, a []float64, lda int, x []float64, incX int) { // Test the input parameters // Verify inputs if o != blas.RowMajor && o != blas.ColMajor { @@ -254,6 +278,8 @@ func Dtrsv(o blas.Order, ul blas.Uplo, tA blas.Transpose, d blas.Diag, n int, a } switch { + default: + panic("col major not yet coded") case o == blas.RowMajor && tA == blas.NoTrans && ul == blas.Upper: jx := kx + (n-1)*incX for j := n; j >= 0; j-- { @@ -320,6 +346,181 @@ func Dtrsv(o blas.Order, ul blas.Uplo, tA blas.Transpose, d blas.Diag, n int, a } } +// Dsymv performs the matrix-vector operation +// y := alpha*A*x + beta*y, +// where alpha and beta are scalars, x and y are n element vectors and +// A is an n by n symmetric matrix. +func (b Blas) Dsymv(o blas.Order, ul blas.Uplo, n int, alpha float64, a []float64, lda int, x []float64, incX int, beta float64, y []float64, incY int) { + // Check inputs + if o != blas.RowMajor && o != blas.ColMajor { + panic(badOrder) + } + if ul != blas.Lower && ul != blas.Upper { + panic(badUplo) + } + if n < 0 { + panic(negativeN) + } + if lda > 1 && lda > n { + panic(badLda) + } + if incX == 0 { + panic(zeroInc) + } + if incY == 0 { + panic(zeroInc) + } + // Quick return if possible + if n == 0 || (alpha == 0 && beta == 1) { + return + } + + // Set up start points + var kx, ky int + if incX > 0 { + kx = 1 + } else { + kx = -(n - 1) * incX + } + if incY > 0 { + ky = 1 + } else { + ky = -(n - 1) * incY + } + + // Form y = beta * y + if beta != 1 { + b.Dscal(n, beta, y, incY) + } + + if alpha == 0 { + return + } + + // TODO: Need to think about changing the major and minor + // looping when row major (help with cache misses) + + // Form y = Ax + y + switch { + default: + panic("not yet coded") + case o == blas.RowMajor && ul == blas.Upper: + jx := kx + jy := ky + for j := 0; j < n; j++ { + tmp1 := alpha * x[jx] + var tmp2 float64 + ix := kx + iy := ky + for i := 0; i < j-2; i++ { + y[iy] += tmp1 * a[i*lda+j] + tmp2 += a[i*lda+j] * x[ix] + ix += incX + iy += incY + } + y[jy] += tmp1*a[j*lda+j] + alpha*tmp2 + jx += incX + jy += incY + } + case o == blas.RowMajor && ul == blas.Lower: + jx := kx + jy := ky + for j := 0; j < n; j++ { + tmp1 := alpha * x[jx] + var tmp2 float64 + y[jy] += tmp1 * a[j*lda+j] + ix := jx + iy := jy + for i := j; i < n; i++ { + ix += incX + iy += incY + y[iy] += tmp1 * a[i*lda+j] + tmp2 += a[i*lda+j] * x[ix] + } + y[jy] += alpha * tmp2 + jx += incX + jy += incY + } + } +} + +// Dger performs the rank 1 operation +// A := alpha*x*y**T + A, +// where alpha is a scalar, x is an m element vector, y is an n element +// vector and A is an m by n matrix. +func (Blas) Dger(o blas.Order, m, n int, alpha float64, x []float64, incX int, y []float64, incY int, a []float64, lda int) { + // Check inputs + if o != blas.RowMajor && o != blas.ColMajor { + panic(badOrder) + } + if m < 0 { + panic("m < 0") + } + if n < 0 { + panic(negativeN) + } + if incX == 0 { + panic(zeroInc) + } + if incY == 0 { + panic(zeroInc) + } + if o == blas.RowMajor { + if lda > 1 && lda > n { + panic(badLda) + } + } else { + if lda > 1 && lda > m { + panic(badLda) + } + } + // Quick return if possible + if m == 0 || n == 0 || alpha == 0 { + return + } + + var jy, kx int + if incY > 0 { + jy = 1 + } else { + jy = -(n - 1) * incY + } + + if incY > 0 { + kx = 1 + } else { + kx = -(m - 1) * incX + } + + switch o { + default: + panic("should not be here") + case blas.RowMajor: + // TODO: Switch this to looping the other way + for j := 0; j < n; j++ { + if y[jy] != 0 { + tmp := alpha * y[jy] + ix := kx + for i := 0; i < m; i++ { + a[i*lda+j] += x[ix] * tmp + } + } + jy += incY + } + case blas.ColMajor: + for j := 0; j < n; j++ { + if y[jy] != 0 { + tmp := alpha * y[jy] + ix := kx + for i := 0; i < m; i++ { + a[j*lda+i] += x[ix] * tmp + } + } + jy += incY + } + } +} + /* // Level 2 routines. Dgbmv(o Order, tA Transpose, m, n, kL, kU int, alpha float64, a []float64, lda int, x []float64, incX int, beta float64, y []float64, incY int) @@ -331,9 +532,6 @@ func Dtrsv(o blas.Order, ul blas.Uplo, tA blas.Transpose, d blas.Diag, n int, a Dspmv(o Order, ul Uplo, n int, alpha float64, ap []float64, x []float64, incX int, beta float64, y []float64, incY int) Dspr(o Order, ul Uplo, n int, alpha float64, x []float64, incX int, ap []float64) Dspr2(o Order, ul Uplo, n int, alpha float64, x []float64, incX int, y []float64, incY int, a []float64) - - Dsymv(o Order, ul Uplo, n int, alpha float64, a []float64, lda int, x []float64, incX int, beta float64, y []float64, incY int) - Dger(o Order, m, n int, alpha float64, x []float64, incX int, y []float64, incY int, a []float64, lda int) Dsyr(o Order, ul Uplo, n int, alpha float64, x []float64, incX int, a []float64, lda int) Dsyr2(o Order, ul Uplo, n int, alpha float64, x []float64, incX int, y []float64, incY int, a []float64, lda int) */ diff --git a/testblas/level2double.go b/testblas/level2double.go index e9e23508..d0105f8b 100644 --- a/testblas/level2double.go +++ b/testblas/level2double.go @@ -9,35 +9,31 @@ import ( // throwPanic will throw unexpected panics if true, or will just report them as errors if false const throwPanic = true -type DoubleMatTwoVecCase struct { - Name string - m int - n int - A [][]float64 - o blas.Order - tA blas.Transpose - x []float64 - incX int - y []float64 - incY int - lda int - xCopy []float64 - yCopy []float64 - Panics bool +type DgemvCase struct { + Name string + m int + n int + A [][]float64 + tA blas.Transpose + x []float64 + incX int + y []float64 + incY int + xCopy []float64 + yCopy []float64 - DgemvCases []DgemvCase + Subcases []DgemvSubcase } -type DgemvCase struct { +type DgemvSubcase struct { alpha float64 beta float64 ans []float64 } -var DoubleMatTwoVecCases []DoubleMatTwoVecCase = []DoubleMatTwoVecCase{ +var DgemvCases []DgemvCase = []DgemvCase{ { - Name: "M_gt_N_Inc1_RowMajor_NoTrans", - o: blas.RowMajor, + Name: "M_gt_N_Inc1_NoTrans", tA: blas.NoTrans, m: 5, n: 3, @@ -48,14 +44,12 @@ var DoubleMatTwoVecCases []DoubleMatTwoVecCase = []DoubleMatTwoVecCase{ {1, 1, 2}, {9, 2, 5}, }, - incX: 1, - incY: 1, - x: []float64{1, 2, 3}, - y: []float64{7, 8, 9, 10, 11}, - lda: 3, - Panics: false, + incX: 1, + incY: 1, + x: []float64{1, 2, 3}, + y: []float64{7, 8, 9, 10, 11}, - DgemvCases: []DgemvCase{ + Subcases: []DgemvSubcase{ { alpha: 0, beta: 0, @@ -79,8 +73,7 @@ var DoubleMatTwoVecCases []DoubleMatTwoVecCase = []DoubleMatTwoVecCase{ }, }, { - Name: "M_gt_N_Inc1_RowMajor_Trans", - o: blas.RowMajor, + Name: "M_gt_N_Inc1_Trans", tA: blas.Trans, m: 5, n: 3, @@ -91,14 +84,12 @@ var DoubleMatTwoVecCases []DoubleMatTwoVecCase = []DoubleMatTwoVecCase{ {1, 1, 2}, {9, 2, 5}, }, - incX: 1, - incY: 1, - x: []float64{1, 2, 3, -4, 5}, - y: []float64{7, 8, 9}, - lda: 3, - Panics: false, + incX: 1, + incY: 1, + x: []float64{1, 2, 3, -4, 5}, + y: []float64{7, 8, 9}, - DgemvCases: []DgemvCase{ + Subcases: []DgemvSubcase{ { alpha: 0, beta: 0, @@ -122,8 +113,7 @@ var DoubleMatTwoVecCases []DoubleMatTwoVecCase = []DoubleMatTwoVecCase{ }, }, { - Name: "M_eq_N_Inc1_RowMajor_NoTrans", - o: blas.RowMajor, + Name: "M_eq_N_Inc1_NoTrans", tA: blas.NoTrans, m: 3, n: 3, @@ -132,14 +122,12 @@ var DoubleMatTwoVecCases []DoubleMatTwoVecCase = []DoubleMatTwoVecCase{ {9.6, 3.5, 9.1}, {10, 7, 3}, }, - incX: 1, - incY: 1, - x: []float64{1, 2, 3}, - y: []float64{7, 2, 2}, - lda: 3, - Panics: false, + incX: 1, + incY: 1, + x: []float64{1, 2, 3}, + y: []float64{7, 2, 2}, - DgemvCases: []DgemvCase{ + Subcases: []DgemvSubcase{ { alpha: 0, beta: 0, @@ -163,8 +151,7 @@ var DoubleMatTwoVecCases []DoubleMatTwoVecCase = []DoubleMatTwoVecCase{ }, }, { - Name: "M_eq_N_Inc1_RowMajor_Trans", - o: blas.RowMajor, + Name: "M_eq_N_Inc1_Trans", tA: blas.Trans, m: 3, n: 3, @@ -173,14 +160,12 @@ var DoubleMatTwoVecCases []DoubleMatTwoVecCase = []DoubleMatTwoVecCase{ {9.6, 3.5, 9.1}, {10, 7, 3}, }, - incX: 1, - incY: 1, - x: []float64{1, 2, 3}, - y: []float64{7, 2, 2}, - lda: 3, - Panics: false, + incX: 1, + incY: 1, + x: []float64{1, 2, 3}, + y: []float64{7, 2, 2}, - DgemvCases: []DgemvCase{ + Subcases: []DgemvSubcase{ { alpha: 8, beta: -6, @@ -189,8 +174,7 @@ var DoubleMatTwoVecCases []DoubleMatTwoVecCase = []DoubleMatTwoVecCase{ }, }, { - Name: "M_lt_N_Inc1_RowMajor_NoTrans", - o: blas.RowMajor, + Name: "M_lt_N_Inc1_NoTrans", tA: blas.NoTrans, m: 3, n: 5, @@ -199,14 +183,12 @@ var DoubleMatTwoVecCases []DoubleMatTwoVecCase = []DoubleMatTwoVecCase{ {9.6, 3.5, 9.1, -2, 9}, {10, 7, 3, 1, -5}, }, - incX: 1, - incY: 1, - x: []float64{1, 2, 3, -7.6, 8.1}, - y: []float64{7, 2, 2}, - lda: 5, - Panics: false, + incX: 1, + incY: 1, + x: []float64{1, 2, 3, -7.6, 8.1}, + y: []float64{7, 2, 2}, - DgemvCases: []DgemvCase{ + Subcases: []DgemvSubcase{ { alpha: 0, beta: 0, @@ -231,8 +213,7 @@ var DoubleMatTwoVecCases []DoubleMatTwoVecCase = []DoubleMatTwoVecCase{ }, }, { - Name: "M_lt_N_Inc1_RowMajor_Trans", - o: blas.RowMajor, + Name: "M_lt_N_Inc1_Trans", tA: blas.Trans, m: 3, n: 5, @@ -241,14 +222,12 @@ var DoubleMatTwoVecCases []DoubleMatTwoVecCase = []DoubleMatTwoVecCase{ {9.6, 3.5, 9.1, -2, 9}, {10, 7, 3, 1, -5}, }, - incX: 1, - incY: 1, - x: []float64{1, 2, 3}, - y: []float64{7, 2, 2, -3, 5}, - lda: 5, - Panics: false, + incX: 1, + incY: 1, + x: []float64{1, 2, 3}, + y: []float64{7, 2, 2, -3, 5}, - DgemvCases: []DgemvCase{ + Subcases: []DgemvSubcase{ { alpha: 8, beta: -6, @@ -257,8 +236,7 @@ var DoubleMatTwoVecCases []DoubleMatTwoVecCase = []DoubleMatTwoVecCase{ }, }, { - Name: "M_gt_N_IncNot1_RowMajor_NoTrans", - o: blas.RowMajor, + Name: "M_gt_N_IncNot1_NoTrans", tA: blas.NoTrans, m: 5, n: 3, @@ -270,13 +248,11 @@ var DoubleMatTwoVecCases []DoubleMatTwoVecCase = []DoubleMatTwoVecCase{ {1, 1, 2}, {9, 2, 5}, }, - incX: 2, - incY: 3, - x: []float64{1, 15, 2, 150, 3}, - y: []float64{7, 2, 6, 8, -4, -5, 9, 1, 1, 10, 19, 22, 11}, - lda: 3, - Panics: false, - DgemvCases: []DgemvCase{ + incX: 2, + incY: 3, + x: []float64{1, 15, 2, 150, 3}, + y: []float64{7, 2, 6, 8, -4, -5, 9, 1, 1, 10, 19, 22, 11}, + Subcases: []DgemvSubcase{ { alpha: 8, beta: -6, @@ -285,8 +261,7 @@ var DoubleMatTwoVecCases []DoubleMatTwoVecCase = []DoubleMatTwoVecCase{ }, }, { - Name: "M_gt_N_IncNot1_RowMajor_Trans", - o: blas.RowMajor, + Name: "M_gt_N_IncNot1_Trans", tA: blas.Trans, m: 5, n: 3, @@ -298,13 +273,11 @@ var DoubleMatTwoVecCases []DoubleMatTwoVecCase = []DoubleMatTwoVecCase{ {1, 1, 2}, {9, 2, 5}, }, - incX: 2, - incY: 3, - x: []float64{1, 15, 2, 150, 3, 8, -3, 6, 5}, - y: []float64{7, 2, 6, 8, -4, -5, 9}, - lda: 3, - Panics: false, - DgemvCases: []DgemvCase{ + incX: 2, + incY: 3, + x: []float64{1, 15, 2, 150, 3, 8, -3, 6, 5}, + y: []float64{7, 2, 6, 8, -4, -5, 9}, + Subcases: []DgemvSubcase{ { alpha: 8, beta: -6, @@ -313,8 +286,7 @@ var DoubleMatTwoVecCases []DoubleMatTwoVecCase = []DoubleMatTwoVecCase{ }, }, { - Name: "M_eq_N_IncNot1_RowMajor_NoTrans", - o: blas.RowMajor, + Name: "M_eq_N_IncNot1_NoTrans", tA: blas.NoTrans, m: 3, n: 3, @@ -323,14 +295,11 @@ var DoubleMatTwoVecCases []DoubleMatTwoVecCase = []DoubleMatTwoVecCase{ {9.6, 3.5, 9.1}, {10, 7, 3}, }, - incX: 2, - incY: 3, - x: []float64{1, 15, 2, 150, 3}, - y: []float64{7, 2, 6, 8, -4, -5, 9}, - lda: 3, - Panics: false, - - DgemvCases: []DgemvCase{ + incX: 2, + incY: 3, + x: []float64{1, 15, 2, 150, 3}, + y: []float64{7, 2, 6, 8, -4, -5, 9}, + Subcases: []DgemvSubcase{ { alpha: 8, beta: -6, @@ -339,8 +308,7 @@ var DoubleMatTwoVecCases []DoubleMatTwoVecCase = []DoubleMatTwoVecCase{ }, }, { - Name: "M_eq_N_IncNot1_RowMajor_Trans", - o: blas.RowMajor, + Name: "M_eq_N_IncNot1_Trans", tA: blas.Trans, m: 3, n: 3, @@ -349,14 +317,12 @@ var DoubleMatTwoVecCases []DoubleMatTwoVecCase = []DoubleMatTwoVecCase{ {9.6, 3.5, 9.1}, {10, 7, 3}, }, - incX: 2, - incY: 3, - x: []float64{1, 15, 2, 150, 3}, - y: []float64{7, 2, 6, 8, -4, -5, 9}, - lda: 3, - Panics: false, + incX: 2, + incY: 3, + x: []float64{1, 15, 2, 150, 3}, + y: []float64{7, 2, 6, 8, -4, -5, 9}, - DgemvCases: []DgemvCase{ + Subcases: []DgemvSubcase{ { alpha: 8, beta: -6, @@ -365,8 +331,7 @@ var DoubleMatTwoVecCases []DoubleMatTwoVecCase = []DoubleMatTwoVecCase{ }, }, { - Name: "M_lt_N_IncNot1_RowMajor_NoTrans", - o: blas.RowMajor, + Name: "M_lt_N_IncNot1_NoTrans", tA: blas.NoTrans, m: 3, n: 5, @@ -375,14 +340,12 @@ var DoubleMatTwoVecCases []DoubleMatTwoVecCase = []DoubleMatTwoVecCase{ {9.6, 3.5, 9.1, -3, -2}, {10, 7, 3, -7, -4}, }, - incX: 2, - incY: 3, - x: []float64{1, 15, 2, 150, 3, -2, -4, 8, -9}, - y: []float64{7, 2, 6, 8, -4, -5, 9}, - lda: 5, - Panics: false, + incX: 2, + incY: 3, + x: []float64{1, 15, 2, 150, 3, -2, -4, 8, -9}, + y: []float64{7, 2, 6, 8, -4, -5, 9}, - DgemvCases: []DgemvCase{ + Subcases: []DgemvSubcase{ { alpha: 8, beta: -6, @@ -391,8 +354,7 @@ var DoubleMatTwoVecCases []DoubleMatTwoVecCase = []DoubleMatTwoVecCase{ }, }, { - Name: "M_lt_N_IncNot1_RowMajor_Trans", - o: blas.RowMajor, + Name: "M_lt_N_IncNot1_Trans", tA: blas.Trans, m: 3, n: 5, @@ -401,14 +363,12 @@ var DoubleMatTwoVecCases []DoubleMatTwoVecCase = []DoubleMatTwoVecCase{ {9.6, 3.5, 9.1, -3, -2}, {10, 7, 3, -7, -4}, }, - incX: 2, - incY: 3, - x: []float64{1, 15, 2, 150, 3}, - y: []float64{7, 2, 6, 8, -4, -5, 9, -4, -1, -9, 1, 1, 2}, - lda: 5, - Panics: false, + incX: 2, + incY: 3, + x: []float64{1, 15, 2, 150, 3}, + y: []float64{7, 2, 6, 8, -4, -5, 9, -4, -1, -9, 1, 1, 2}, - DgemvCases: []DgemvCase{ + Subcases: []DgemvSubcase{ { alpha: 8, beta: -6, @@ -419,8 +379,6 @@ var DoubleMatTwoVecCases []DoubleMatTwoVecCase = []DoubleMatTwoVecCase{ // TODO: A can be longer than mxn. Add cases where it is longer // TODO: x and y can also be longer. Add tests for these - // TODO: Add column major - // TODO: Add tests for all the bad inputs // TODO: Add tests for dimension mismatch // TODO: Add negative increments // TODO: Add places with a "submatrix view", where lda != m @@ -472,37 +430,116 @@ type Dgemver interface { } func DgemvTest(t *testing.T, blasser Dgemver) { - for _, test := range DoubleMatTwoVecCases { - for i, cas := range test.DgemvCases { - x := sliceCopy(test.x) - y := sliceCopy(test.y) - a := sliceOfSliceCopy(test.A) - aFlat := flatten(a, test.o) - f := func() { - blasser.Dgemv(test.o, test.tA, test.m, test.n, cas.alpha, aFlat, test.lda, x, test.incX, cas.beta, y, test.incY) - } - if panics(f) { - if !test.Panics { - t.Errorf("Test %v case %v unexpected panic", test.Name, i) - if throwPanic { - blasser.Dgemv(test.o, test.tA, test.m, test.n, cas.alpha, aFlat, test.lda, x, test.incX, cas.beta, y, test.incY) - } - } - continue - } - // Check that x and a are unchanged - if !dSliceEqual(x, test.x) { - t.Errorf("Test %v, case %v x modified during call", test.Name, i) - } - aFlat2 := flatten(sliceOfSliceCopy(test.A), test.o) - if !dSliceEqual(aFlat2, aFlat) { - t.Errorf("Test %v, case %v a modified during call", test.Name, i) - } + for _, test := range DgemvCases { + for i, cas := range test.Subcases { + // Test that it passes with row-major + dgemvcomp(t, blas.RowMajor, test, cas, i, blasser) - // Check that the answer matches - if !dSliceTolEqual(cas.ans, y) { - t.Errorf("Test %v, case %v answer mismatch: Expected %v, Found %v", test.Name, i, cas.ans, y) - } + // Test that it passes with col-major + dgemvcomp(t, blas.ColMajor, test, cas, i, blasser) + + // Test the bad inputs + dgemvbad(t, test, cas, i, blasser) } } } + +func dgemvcomp(t *testing.T, o blas.Order, test DgemvCase, cas DgemvSubcase, i int, blasser Dgemver) { + x := sliceCopy(test.x) + y := sliceCopy(test.y) + a := sliceOfSliceCopy(test.A) + aFlat := flatten(a, o) + + var lda int + if o == blas.RowMajor { + lda = test.n + } else if o == blas.ColMajor { + lda = test.m + } else { + panic("bad order") + } + + f := func() { + blasser.Dgemv(o, test.tA, test.m, test.n, cas.alpha, aFlat, lda, x, test.incX, cas.beta, y, test.incY) + } + if panics(f) { + t.Errorf("Test %v case %v order %v unexpected panic", test.Name, i, o) + if throwPanic { + blasser.Dgemv(o, test.tA, test.m, test.n, cas.alpha, aFlat, lda, x, test.incX, cas.beta, y, test.incY) + } + + return + } + // Check that x and a are unchanged + if !dSliceEqual(x, test.x) { + t.Errorf("Test %v, case %v order %v: x modified during call", test.Name, i, o) + } + aFlat2 := flatten(sliceOfSliceCopy(test.A), o) + if !dSliceEqual(aFlat2, aFlat) { + t.Errorf("Test %v, case %v order %v: a modified during call", test.Name, i, o) + } + + // Check that the answer matches + if !dSliceTolEqual(cas.ans, y) { + t.Errorf("Test %v, case %v order %v: answer mismatch: Expected %v, Found %v", test.Name, i, o, cas.ans, y) + } +} + +func dgemvbad(t *testing.T, test DgemvCase, cas DgemvSubcase, i int, blasser Dgemver) { + x := sliceCopy(test.x) + y := sliceCopy(test.y) + a := sliceOfSliceCopy(test.A) + aFlatRow := flatten(a, blas.RowMajor) + ldaRow := test.n + aFlatCol := flatten(a, blas.ColMajor) + ldaCol := test.m + // Test that panics on bad order + f := func() { + blasser.Dgemv(312, test.tA, test.m, test.n, cas.alpha, aFlatRow, ldaRow, x, test.incX, cas.beta, y, test.incY) + } + if !panics(f) { + t.Errorf("Test %v case %v: no panic for bad order", test.Name, i) + } + f = func() { + blasser.Dgemv(blas.RowMajor, 312, test.m, test.n, cas.alpha, aFlatRow, ldaRow, x, test.incX, cas.beta, y, test.incY) + } + if !panics(f) { + t.Errorf("Test %v case %v: no panic for bad transpose", test.Name, i) + } + f = func() { + blasser.Dgemv(blas.RowMajor, test.tA, -2, test.n, cas.alpha, aFlatRow, ldaRow, x, test.incX, cas.beta, y, test.incY) + } + if !panics(f) { + t.Errorf("Test %v case %v: no panic for m negative", test.Name, i) + } + f = func() { + blasser.Dgemv(blas.RowMajor, test.tA, test.m, -4, cas.alpha, aFlatRow, ldaRow, x, test.incX, cas.beta, y, test.incY) + } + if !panics(f) { + t.Errorf("Test %v case %v: no panic for n negative", test.Name, i) + } + f = func() { + blasser.Dgemv(blas.RowMajor, test.tA, test.m, test.n, cas.alpha, aFlatRow, ldaRow, x, 0, cas.beta, y, test.incY) + } + if !panics(f) { + t.Errorf("Test %v case %v: no panic for incX zero", test.Name, i) + } + f = func() { + blasser.Dgemv(blas.RowMajor, test.tA, test.m, test.n, cas.alpha, aFlatRow, ldaRow, x, test.incX, cas.beta, y, 0) + } + if !panics(f) { + t.Errorf("Test %v case %v: no panic for incY zero", test.Name, i) + } + f = func() { + blasser.Dgemv(blas.RowMajor, test.tA, test.m, test.n, cas.alpha, aFlatRow, ldaRow+3, x, test.incX, cas.beta, y, test.incY) + } + if !panics(f) { + t.Errorf("Test %v case %v: no panic for lda too large row", test.Name, i) + } + f = func() { + blasser.Dgemv(blas.RowMajor, test.tA, test.m, test.n, cas.alpha, aFlatCol, ldaCol+3, x, test.incX, cas.beta, y, test.incY) + } + if !panics(f) { + t.Errorf("Test %v case %v: no panic for lda too large col", test.Name, i) + } +}