mirror of
				https://github.com/gonum/gonum.git
				synced 2025-10-31 02:26:59 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			200 lines
		
	
	
		
			5.3 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			200 lines
		
	
	
		
			5.3 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| // Copyright ©2015 The Gonum Authors. All rights reserved.
 | |
| // Use of this source code is governed by a BSD-style
 | |
| // license that can be found in the LICENSE file.
 | |
| 
 | |
| package testlapack
 | |
| 
 | |
| import (
 | |
| 	"fmt"
 | |
| 	"sort"
 | |
| 	"testing"
 | |
| 
 | |
| 	"golang.org/x/exp/rand"
 | |
| 
 | |
| 	"gonum.org/v1/gonum/blas"
 | |
| 	"gonum.org/v1/gonum/blas/blas64"
 | |
| 	"gonum.org/v1/gonum/floats"
 | |
| 	"gonum.org/v1/gonum/floats/scalar"
 | |
| )
 | |
| 
 | |
| type Dbdsqrer interface {
 | |
| 	Dbdsqr(uplo blas.Uplo, n, ncvt, nru, ncc int, d, e, vt []float64, ldvt int, u []float64, ldu int, c []float64, ldc int, work []float64) (ok bool)
 | |
| }
 | |
| 
 | |
| func DbdsqrTest(t *testing.T, impl Dbdsqrer) {
 | |
| 	rnd := rand.New(rand.NewSource(1))
 | |
| 	bi := blas64.Implementation()
 | |
| 	for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} {
 | |
| 		for _, test := range []struct {
 | |
| 			n, ncvt, nru, ncc, ldvt, ldu, ldc int
 | |
| 		}{
 | |
| 			{5, 5, 5, 5, 0, 0, 0},
 | |
| 			{10, 10, 10, 10, 0, 0, 0},
 | |
| 			{10, 11, 12, 13, 0, 0, 0},
 | |
| 			{20, 13, 12, 11, 0, 0, 0},
 | |
| 
 | |
| 			{5, 5, 5, 5, 6, 7, 8},
 | |
| 			{10, 10, 10, 10, 30, 40, 50},
 | |
| 			{10, 12, 11, 13, 30, 40, 50},
 | |
| 			{20, 12, 13, 11, 30, 40, 50},
 | |
| 
 | |
| 			{130, 130, 130, 500, 900, 900, 500},
 | |
| 		} {
 | |
| 			for cas := 0; cas < 10; cas++ {
 | |
| 				n := test.n
 | |
| 				ncvt := test.ncvt
 | |
| 				nru := test.nru
 | |
| 				ncc := test.ncc
 | |
| 				ldvt := test.ldvt
 | |
| 				ldu := test.ldu
 | |
| 				ldc := test.ldc
 | |
| 				if ldvt == 0 {
 | |
| 					ldvt = max(1, ncvt)
 | |
| 				}
 | |
| 				if ldu == 0 {
 | |
| 					ldu = max(1, n)
 | |
| 				}
 | |
| 				if ldc == 0 {
 | |
| 					ldc = max(1, ncc)
 | |
| 				}
 | |
| 
 | |
| 				d := make([]float64, n)
 | |
| 				for i := range d {
 | |
| 					d[i] = rnd.NormFloat64()
 | |
| 				}
 | |
| 				e := make([]float64, n-1)
 | |
| 				for i := range e {
 | |
| 					e[i] = rnd.NormFloat64()
 | |
| 				}
 | |
| 				dCopy := make([]float64, len(d))
 | |
| 				copy(dCopy, d)
 | |
| 				eCopy := make([]float64, len(e))
 | |
| 				copy(eCopy, e)
 | |
| 				work := make([]float64, 4*(n-1))
 | |
| 				for i := range work {
 | |
| 					work[i] = rnd.NormFloat64()
 | |
| 				}
 | |
| 
 | |
| 				// First test the decomposition of the bidiagonal matrix. Set
 | |
| 				// pt and u equal to I with the correct size. At the result
 | |
| 				// of Dbdsqr, p and u  will contain the data of Pᵀ and Q, which
 | |
| 				// will be used in the next step to test the multiplication
 | |
| 				// with Q and VT.
 | |
| 
 | |
| 				q := make([]float64, n*n)
 | |
| 				ldq := n
 | |
| 				pt := make([]float64, n*n)
 | |
| 				ldpt := n
 | |
| 				for i := 0; i < n; i++ {
 | |
| 					q[i*ldq+i] = 1
 | |
| 				}
 | |
| 				for i := 0; i < n; i++ {
 | |
| 					pt[i*ldpt+i] = 1
 | |
| 				}
 | |
| 
 | |
| 				ok := impl.Dbdsqr(uplo, n, n, n, 0, d, e, pt, ldpt, q, ldq, nil, 1, work)
 | |
| 
 | |
| 				isUpper := uplo == blas.Upper
 | |
| 				errStr := fmt.Sprintf("isUpper = %v, n = %v, ncvt = %v, nru = %v, ncc = %v", isUpper, n, ncvt, nru, ncc)
 | |
| 				if !ok {
 | |
| 					t.Errorf("Unexpected Dbdsqr failure: %s", errStr)
 | |
| 				}
 | |
| 
 | |
| 				bMat := constructBidiagonal(uplo, n, dCopy, eCopy)
 | |
| 				sMat := constructBidiagonal(uplo, n, d, e)
 | |
| 
 | |
| 				tmp := blas64.General{
 | |
| 					Rows:   n,
 | |
| 					Cols:   n,
 | |
| 					Stride: n,
 | |
| 					Data:   make([]float64, n*n),
 | |
| 				}
 | |
| 				ansMat := blas64.General{
 | |
| 					Rows:   n,
 | |
| 					Cols:   n,
 | |
| 					Stride: n,
 | |
| 					Data:   make([]float64, n*n),
 | |
| 				}
 | |
| 
 | |
| 				bi.Dgemm(blas.NoTrans, blas.NoTrans, n, n, n, 1, q, ldq, sMat.Data, sMat.Stride, 0, tmp.Data, tmp.Stride)
 | |
| 				bi.Dgemm(blas.NoTrans, blas.NoTrans, n, n, n, 1, tmp.Data, tmp.Stride, pt, ldpt, 0, ansMat.Data, ansMat.Stride)
 | |
| 
 | |
| 				same := true
 | |
| 				for i := 0; i < n; i++ {
 | |
| 					for j := 0; j < n; j++ {
 | |
| 						if !scalar.EqualWithinAbsOrRel(ansMat.Data[i*ansMat.Stride+j], bMat.Data[i*bMat.Stride+j], 1e-8, 1e-8) {
 | |
| 							same = false
 | |
| 						}
 | |
| 					}
 | |
| 				}
 | |
| 				if !same {
 | |
| 					t.Errorf("Bidiagonal mismatch. %s", errStr)
 | |
| 				}
 | |
| 				if !sort.IsSorted(sort.Reverse(sort.Float64Slice(d))) {
 | |
| 					t.Errorf("D is not sorted. %s", errStr)
 | |
| 				}
 | |
| 
 | |
| 				// The above computed the real P and Q. Now input data for Vᵀ,
 | |
| 				// U, and C to check that the multiplications happen properly.
 | |
| 				dAns := make([]float64, len(d))
 | |
| 				copy(dAns, d)
 | |
| 				eAns := make([]float64, len(e))
 | |
| 				copy(eAns, e)
 | |
| 
 | |
| 				u := make([]float64, nru*ldu)
 | |
| 				for i := range u {
 | |
| 					u[i] = rnd.NormFloat64()
 | |
| 				}
 | |
| 				uCopy := make([]float64, len(u))
 | |
| 				copy(uCopy, u)
 | |
| 				vt := make([]float64, n*ldvt)
 | |
| 				for i := range vt {
 | |
| 					vt[i] = rnd.NormFloat64()
 | |
| 				}
 | |
| 				vtCopy := make([]float64, len(vt))
 | |
| 				copy(vtCopy, vt)
 | |
| 				c := make([]float64, n*ldc)
 | |
| 				for i := range c {
 | |
| 					c[i] = rnd.NormFloat64()
 | |
| 				}
 | |
| 				cCopy := make([]float64, len(c))
 | |
| 				copy(cCopy, c)
 | |
| 
 | |
| 				// Reset input data
 | |
| 				copy(d, dCopy)
 | |
| 				copy(e, eCopy)
 | |
| 				impl.Dbdsqr(uplo, n, ncvt, nru, ncc, d, e, vt, ldvt, u, ldu, c, ldc, work)
 | |
| 
 | |
| 				// Check result.
 | |
| 				if !floats.EqualApprox(d, dAns, 1e-14) {
 | |
| 					t.Errorf("D mismatch second time. %s", errStr)
 | |
| 				}
 | |
| 				if !floats.EqualApprox(e, eAns, 1e-14) {
 | |
| 					t.Errorf("E mismatch second time. %s", errStr)
 | |
| 				}
 | |
| 				ans := make([]float64, len(vtCopy))
 | |
| 				copy(ans, vtCopy)
 | |
| 				ldans := ldvt
 | |
| 				bi.Dgemm(blas.NoTrans, blas.NoTrans, n, ncvt, n, 1, pt, ldpt, vtCopy, ldvt, 0, ans, ldans)
 | |
| 				if !floats.EqualApprox(ans, vt, 1e-10) {
 | |
| 					t.Errorf("Vt result mismatch. %s", errStr)
 | |
| 				}
 | |
| 				ans = make([]float64, len(uCopy))
 | |
| 				copy(ans, uCopy)
 | |
| 				ldans = ldu
 | |
| 				bi.Dgemm(blas.NoTrans, blas.NoTrans, nru, n, n, 1, uCopy, ldu, q, ldq, 0, ans, ldans)
 | |
| 				if !floats.EqualApprox(ans, u, 1e-10) {
 | |
| 					t.Errorf("U result mismatch. %s", errStr)
 | |
| 				}
 | |
| 				ans = make([]float64, len(cCopy))
 | |
| 				copy(ans, cCopy)
 | |
| 				ldans = ldc
 | |
| 				bi.Dgemm(blas.Trans, blas.NoTrans, n, ncc, n, 1, q, ldq, cCopy, ldc, 0, ans, ldans)
 | |
| 				if !floats.EqualApprox(ans, c, 1e-10) {
 | |
| 					t.Errorf("C result mismatch. %s", errStr)
 | |
| 				}
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| }
 | 
