mirror of
				https://github.com/gonum/gonum.git
				synced 2025-10-31 18:42:45 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			162 lines
		
	
	
		
			4.4 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			162 lines
		
	
	
		
			4.4 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| // Copyright ©2015 The gonum Authors. All rights reserved.
 | |
| // Use of this source code is governed by a BSD-style
 | |
| // license that can be found in the LICENSE file.
 | |
| 
 | |
| package testlapack
 | |
| 
 | |
| import (
 | |
| 	"math/rand"
 | |
| 	"testing"
 | |
| 
 | |
| 	"gonum.org/v1/gonum/blas"
 | |
| 	"gonum.org/v1/gonum/blas/blas64"
 | |
| 	"gonum.org/v1/gonum/floats"
 | |
| 	"gonum.org/v1/gonum/lapack"
 | |
| )
 | |
| 
 | |
| type Dlarfber interface {
 | |
| 	Dlarfter
 | |
| 	Dlarfb(side blas.Side, trans blas.Transpose, direct lapack.Direct,
 | |
| 		store lapack.StoreV, m, n, k int, v []float64, ldv int, t []float64, ldt int,
 | |
| 		c []float64, ldc int, work []float64, ldwork int)
 | |
| }
 | |
| 
 | |
| func DlarfbTest(t *testing.T, impl Dlarfber) {
 | |
| 	rnd := rand.New(rand.NewSource(1))
 | |
| 	for _, store := range []lapack.StoreV{lapack.ColumnWise, lapack.RowWise} {
 | |
| 		for _, direct := range []lapack.Direct{lapack.Forward, lapack.Backward} {
 | |
| 			for _, side := range []blas.Side{blas.Left, blas.Right} {
 | |
| 				for _, trans := range []blas.Transpose{blas.Trans, blas.NoTrans} {
 | |
| 					for cas, test := range []struct {
 | |
| 						ma, na, cdim, lda, ldt, ldc int
 | |
| 					}{
 | |
| 						{6, 6, 6, 0, 0, 0},
 | |
| 						{6, 8, 10, 0, 0, 0},
 | |
| 						{6, 10, 8, 0, 0, 0},
 | |
| 						{8, 6, 10, 0, 0, 0},
 | |
| 						{8, 10, 6, 0, 0, 0},
 | |
| 						{10, 6, 8, 0, 0, 0},
 | |
| 						{10, 8, 6, 0, 0, 0},
 | |
| 						{6, 6, 6, 12, 15, 30},
 | |
| 						{6, 8, 10, 12, 15, 30},
 | |
| 						{6, 10, 8, 12, 15, 30},
 | |
| 						{8, 6, 10, 12, 15, 30},
 | |
| 						{8, 10, 6, 12, 15, 30},
 | |
| 						{10, 6, 8, 12, 15, 30},
 | |
| 						{10, 8, 6, 12, 15, 30},
 | |
| 						{6, 6, 6, 15, 12, 30},
 | |
| 						{6, 8, 10, 15, 12, 30},
 | |
| 						{6, 10, 8, 15, 12, 30},
 | |
| 						{8, 6, 10, 15, 12, 30},
 | |
| 						{8, 10, 6, 15, 12, 30},
 | |
| 						{10, 6, 8, 15, 12, 30},
 | |
| 						{10, 8, 6, 15, 12, 30},
 | |
| 					} {
 | |
| 						// Generate a matrix for QR
 | |
| 						ma := test.ma
 | |
| 						na := test.na
 | |
| 						lda := test.lda
 | |
| 						if lda == 0 {
 | |
| 							lda = na
 | |
| 						}
 | |
| 						a := make([]float64, ma*lda)
 | |
| 						for i := 0; i < ma; i++ {
 | |
| 							for j := 0; j < lda; j++ {
 | |
| 								a[i*lda+j] = rnd.Float64()
 | |
| 							}
 | |
| 						}
 | |
| 						k := min(ma, na)
 | |
| 
 | |
| 						// H is always ma x ma
 | |
| 						var m, n, rowsWork int
 | |
| 						switch {
 | |
| 						default:
 | |
| 							panic("not implemented")
 | |
| 						case side == blas.Left:
 | |
| 							m = test.ma
 | |
| 							n = test.cdim
 | |
| 							rowsWork = n
 | |
| 						case side == blas.Right:
 | |
| 							m = test.cdim
 | |
| 							n = test.ma
 | |
| 							rowsWork = m
 | |
| 						}
 | |
| 
 | |
| 						// Use dgeqr2 to find the v vectors
 | |
| 						tau := make([]float64, na)
 | |
| 						work := make([]float64, na)
 | |
| 						impl.Dgeqr2(ma, k, a, lda, tau, work)
 | |
| 
 | |
| 						// Correct the v vectors based on the direct and store
 | |
| 						vMatTmp := extractVMat(ma, na, a, lda, lapack.Forward, lapack.ColumnWise)
 | |
| 						vMat := constructVMat(vMatTmp, store, direct)
 | |
| 						v := vMat.Data
 | |
| 						ldv := vMat.Stride
 | |
| 
 | |
| 						// Use dlarft to find the t vector
 | |
| 						ldt := test.ldt
 | |
| 						if ldt == 0 {
 | |
| 							ldt = k
 | |
| 						}
 | |
| 						tm := make([]float64, k*ldt)
 | |
| 
 | |
| 						impl.Dlarft(direct, store, ma, k, v, ldv, tau, tm, ldt)
 | |
| 
 | |
| 						// Generate c matrix
 | |
| 						ldc := test.ldc
 | |
| 						if ldc == 0 {
 | |
| 							ldc = n
 | |
| 						}
 | |
| 						c := make([]float64, m*ldc)
 | |
| 						for i := 0; i < m; i++ {
 | |
| 							for j := 0; j < ldc; j++ {
 | |
| 								c[i*ldc+j] = rnd.Float64()
 | |
| 							}
 | |
| 						}
 | |
| 						cCopy := make([]float64, len(c))
 | |
| 						copy(cCopy, c)
 | |
| 
 | |
| 						ldwork := k
 | |
| 						work = make([]float64, rowsWork*k)
 | |
| 
 | |
| 						// Call Dlarfb with this information
 | |
| 						impl.Dlarfb(side, trans, direct, store, m, n, k, v, ldv, tm, ldt, c, ldc, work, ldwork)
 | |
| 
 | |
| 						h := constructH(tau, vMat, store, direct)
 | |
| 
 | |
| 						cMat := blas64.General{
 | |
| 							Rows:   m,
 | |
| 							Cols:   n,
 | |
| 							Stride: ldc,
 | |
| 							Data:   make([]float64, m*ldc),
 | |
| 						}
 | |
| 						copy(cMat.Data, cCopy)
 | |
| 						ans := blas64.General{
 | |
| 							Rows:   m,
 | |
| 							Cols:   n,
 | |
| 							Stride: ldc,
 | |
| 							Data:   make([]float64, m*ldc),
 | |
| 						}
 | |
| 						copy(ans.Data, cMat.Data)
 | |
| 						switch {
 | |
| 						default:
 | |
| 							panic("not implemented")
 | |
| 						case side == blas.Left && trans == blas.NoTrans:
 | |
| 							blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, h, cMat, 0, ans)
 | |
| 						case side == blas.Left && trans == blas.Trans:
 | |
| 							blas64.Gemm(blas.Trans, blas.NoTrans, 1, h, cMat, 0, ans)
 | |
| 						case side == blas.Right && trans == blas.NoTrans:
 | |
| 							blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, cMat, h, 0, ans)
 | |
| 						case side == blas.Right && trans == blas.Trans:
 | |
| 							blas64.Gemm(blas.NoTrans, blas.Trans, 1, cMat, h, 0, ans)
 | |
| 						}
 | |
| 						if !floats.EqualApprox(ans.Data, c, 1e-14) {
 | |
| 							t.Errorf("Cas %v mismatch. Want %v, got %v.", cas, ans.Data, c)
 | |
| 						}
 | |
| 					}
 | |
| 				}
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| }
 | 
