mirror of
				https://github.com/gonum/gonum.git
				synced 2025-10-27 01:00:26 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			235 lines
		
	
	
		
			4.4 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			235 lines
		
	
	
		
			4.4 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package testblas
 | |
| 
 | |
| import (
 | |
| 	"math"
 | |
| 	"math/cmplx"
 | |
| 	"testing"
 | |
| 
 | |
| 	"golang.org/x/exp/rand"
 | |
| 
 | |
| 	"gonum.org/v1/gonum/blas"
 | |
| 	"gonum.org/v1/gonum/floats"
 | |
| )
 | |
| 
 | |
| func TestFlattenBanded(t *testing.T) {
 | |
| 	for i, test := range []struct {
 | |
| 		dense     [][]float64
 | |
| 		ku        int
 | |
| 		kl        int
 | |
| 		condensed [][]float64
 | |
| 	}{
 | |
| 		{
 | |
| 			dense:     [][]float64{{3}},
 | |
| 			ku:        0,
 | |
| 			kl:        0,
 | |
| 			condensed: [][]float64{{3}},
 | |
| 		},
 | |
| 		{
 | |
| 			dense: [][]float64{
 | |
| 				{3, 4, 0},
 | |
| 			},
 | |
| 			ku: 1,
 | |
| 			kl: 0,
 | |
| 			condensed: [][]float64{
 | |
| 				{3, 4},
 | |
| 			},
 | |
| 		},
 | |
| 		{
 | |
| 			dense: [][]float64{
 | |
| 				{3, 4, 0, 0, 0},
 | |
| 			},
 | |
| 			ku: 1,
 | |
| 			kl: 0,
 | |
| 			condensed: [][]float64{
 | |
| 				{3, 4},
 | |
| 			},
 | |
| 		},
 | |
| 		{
 | |
| 			dense: [][]float64{
 | |
| 				{3, 4, 0},
 | |
| 				{0, 5, 8},
 | |
| 				{0, 0, 2},
 | |
| 				{0, 0, 0},
 | |
| 				{0, 0, 0},
 | |
| 			},
 | |
| 			ku: 1,
 | |
| 			kl: 0,
 | |
| 			condensed: [][]float64{
 | |
| 				{3, 4},
 | |
| 				{5, 8},
 | |
| 				{2, math.NaN()},
 | |
| 				{math.NaN(), math.NaN()},
 | |
| 				{math.NaN(), math.NaN()},
 | |
| 			},
 | |
| 		},
 | |
| 		{
 | |
| 			dense: [][]float64{
 | |
| 				{3, 4, 6},
 | |
| 				{0, 5, 8},
 | |
| 				{0, 0, 2},
 | |
| 				{0, 0, 0},
 | |
| 				{0, 0, 0},
 | |
| 			},
 | |
| 			ku: 2,
 | |
| 			kl: 0,
 | |
| 			condensed: [][]float64{
 | |
| 				{3, 4, 6},
 | |
| 				{5, 8, math.NaN()},
 | |
| 				{2, math.NaN(), math.NaN()},
 | |
| 				{math.NaN(), math.NaN(), math.NaN()},
 | |
| 				{math.NaN(), math.NaN(), math.NaN()},
 | |
| 			},
 | |
| 		},
 | |
| 		{
 | |
| 			dense: [][]float64{
 | |
| 				{3, 4, 6},
 | |
| 				{1, 5, 8},
 | |
| 				{0, 6, 2},
 | |
| 				{0, 0, 7},
 | |
| 				{0, 0, 0},
 | |
| 			},
 | |
| 			ku: 2,
 | |
| 			kl: 1,
 | |
| 			condensed: [][]float64{
 | |
| 				{math.NaN(), 3, 4, 6},
 | |
| 				{1, 5, 8, math.NaN()},
 | |
| 				{6, 2, math.NaN(), math.NaN()},
 | |
| 				{7, math.NaN(), math.NaN(), math.NaN()},
 | |
| 				{math.NaN(), math.NaN(), math.NaN(), math.NaN()},
 | |
| 			},
 | |
| 		},
 | |
| 		{
 | |
| 			dense: [][]float64{
 | |
| 				{1, 2, 0},
 | |
| 				{3, 4, 5},
 | |
| 				{6, 7, 8},
 | |
| 				{0, 9, 10},
 | |
| 				{0, 0, 11},
 | |
| 			},
 | |
| 			ku: 1,
 | |
| 			kl: 2,
 | |
| 			condensed: [][]float64{
 | |
| 				{math.NaN(), math.NaN(), 1, 2},
 | |
| 				{math.NaN(), 3, 4, 5},
 | |
| 				{6, 7, 8, math.NaN()},
 | |
| 				{9, 10, math.NaN(), math.NaN()},
 | |
| 				{11, math.NaN(), math.NaN(), math.NaN()},
 | |
| 			},
 | |
| 		},
 | |
| 		{
 | |
| 			dense: [][]float64{
 | |
| 				{1, 0, 0},
 | |
| 				{3, 4, 0},
 | |
| 				{6, 7, 8},
 | |
| 				{0, 9, 10},
 | |
| 				{0, 0, 11},
 | |
| 			},
 | |
| 			ku: 0,
 | |
| 			kl: 2,
 | |
| 			condensed: [][]float64{
 | |
| 				{math.NaN(), math.NaN(), 1},
 | |
| 				{math.NaN(), 3, 4},
 | |
| 				{6, 7, 8},
 | |
| 				{9, 10, math.NaN()},
 | |
| 				{11, math.NaN(), math.NaN()},
 | |
| 			},
 | |
| 		},
 | |
| 		{
 | |
| 			dense: [][]float64{
 | |
| 				{1, 0, 0, 0, 0},
 | |
| 				{3, 4, 0, 0, 0},
 | |
| 				{1, 3, 5, 0, 0},
 | |
| 			},
 | |
| 			ku: 0,
 | |
| 			kl: 2,
 | |
| 			condensed: [][]float64{
 | |
| 				{math.NaN(), math.NaN(), 1},
 | |
| 				{math.NaN(), 3, 4},
 | |
| 				{1, 3, 5},
 | |
| 			},
 | |
| 		},
 | |
| 	} {
 | |
| 		condensed := flattenBanded(test.dense, test.ku, test.kl)
 | |
| 		correct := flatten(test.condensed)
 | |
| 		if !floats.Same(condensed, correct) {
 | |
| 			t.Errorf("Case %v mismatch. Want %v, got %v.", i, correct, condensed)
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestFlattenTriangular(t *testing.T) {
 | |
| 	for i, test := range []struct {
 | |
| 		a   [][]float64
 | |
| 		ans []float64
 | |
| 		ul  blas.Uplo
 | |
| 	}{
 | |
| 		{
 | |
| 			a: [][]float64{
 | |
| 				{1, 2, 3},
 | |
| 				{0, 4, 5},
 | |
| 				{0, 0, 6},
 | |
| 			},
 | |
| 			ul:  blas.Upper,
 | |
| 			ans: []float64{1, 2, 3, 4, 5, 6},
 | |
| 		},
 | |
| 		{
 | |
| 			a: [][]float64{
 | |
| 				{1, 0, 0},
 | |
| 				{2, 3, 0},
 | |
| 				{4, 5, 6},
 | |
| 			},
 | |
| 			ul:  blas.Lower,
 | |
| 			ans: []float64{1, 2, 3, 4, 5, 6},
 | |
| 		},
 | |
| 	} {
 | |
| 		a := flattenTriangular(test.a, test.ul)
 | |
| 		if !floats.Equal(a, test.ans) {
 | |
| 			t.Errorf("Case %v. Want %v, got %v.", i, test.ans, a)
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func TestPackUnpackHermitian(t *testing.T) {
 | |
| 	rnd := rand.New(rand.NewSource(1))
 | |
| 	for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} {
 | |
| 		for _, n := range []int{1, 2, 5, 50} {
 | |
| 			for _, lda := range []int{max(1, n), n + 11} {
 | |
| 				a := makeZGeneral(nil, n, n, lda)
 | |
| 				for i := 0; i < n; i++ {
 | |
| 					for j := i; j < n; j++ {
 | |
| 						a[i*lda+j] = complex(rnd.NormFloat64(), rnd.NormFloat64())
 | |
| 						if i != j {
 | |
| 							a[j*lda+i] = cmplx.Conj(a[i*lda+j])
 | |
| 						}
 | |
| 					}
 | |
| 				}
 | |
| 				aCopy := make([]complex128, len(a))
 | |
| 				copy(aCopy, a)
 | |
| 
 | |
| 				ap := packHermitian(uplo, n, a, lda)
 | |
| 				if !zsame(a, aCopy) {
 | |
| 					t.Errorf("Case uplo=%v,n=%v,lda=%v: packHermitian modified a", uplo, n, lda)
 | |
| 				}
 | |
| 
 | |
| 				apCopy := make([]complex128, len(ap))
 | |
| 				copy(apCopy, ap)
 | |
| 
 | |
| 				art := unpackHermitian(uplo, n, ap)
 | |
| 				if !zsame(ap, apCopy) {
 | |
| 					t.Errorf("Case uplo=%v,n=%v,lda=%v: unpackHermitian modified ap", uplo, n, lda)
 | |
| 				}
 | |
| 
 | |
| 				// Copy the round-tripped A into a matrix with the same stride
 | |
| 				// as the original.
 | |
| 				got := makeZGeneral(nil, n, n, lda)
 | |
| 				for i := 0; i < n; i++ {
 | |
| 					copy(got[i*lda:i*lda+n], art[i*n:i*n+n])
 | |
| 				}
 | |
| 				if !zsame(got, a) {
 | |
| 					t.Errorf("Case uplo=%v,n=%v,lda=%v: packHermitian and unpackHermitian do not roundtrip", uplo, n, lda)
 | |
| 				}
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| }
 | 
