mirror of
				https://github.com/gonum/gonum.git
				synced 2025-10-31 02:26:59 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			177 lines
		
	
	
		
			3.9 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			177 lines
		
	
	
		
			3.9 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| // Copyright ©2015 The Gonum Authors. All rights reserved.
 | |
| // Use of this source code is governed by a BSD-style
 | |
| // license that can be found in the LICENSE file.
 | |
| 
 | |
| package testlapack
 | |
| 
 | |
| import (
 | |
| 	"fmt"
 | |
| 	"testing"
 | |
| 
 | |
| 	"golang.org/x/exp/rand"
 | |
| 
 | |
| 	"gonum.org/v1/gonum/blas"
 | |
| 	"gonum.org/v1/gonum/blas/blas64"
 | |
| 	"gonum.org/v1/gonum/floats"
 | |
| 	"gonum.org/v1/gonum/lapack"
 | |
| )
 | |
| 
 | |
| type Dlarfer interface {
 | |
| 	Dlarf(side blas.Side, m, n int, v []float64, incv int, tau float64, c []float64, ldc int, work []float64)
 | |
| }
 | |
| 
 | |
// DlarfTest checks impl.Dlarf. It runs the full parameter sweep once for
// blas.Left and once for blas.Right, each as its own subtest named after
// the side.
func DlarfTest(t *testing.T, impl Dlarfer) {
	sides := [...]blas.Side{blas.Left, blas.Right}
	for i := range sides {
		s := sides[i]
		t.Run(sideToString(s), func(t *testing.T) {
			runDlarfTest(t, impl, s)
		})
	}
}
 | |
| 
 | |
// runDlarfTest sweeps the parameters for a single side and hands each
// combination to dlarfTest. The swept parameters are:
//   - matrix dimensions m and n,
//   - the stride of v,
//   - the leading dimension of C (tight and padded),
//   - the sparsity patterns of v and C,
//   - tau (zero, then one freshly drawn random value).
//
// A fixed seed keeps the generated inputs reproducible.
func runDlarfTest(t *testing.T, impl Dlarfer, side blas.Side) {
	rnd := rand.New(rand.NewSource(1))
	dims := []int{0, 1, 2, 3, 4, 5, 10}
	strides := []int{1, 4}
	sparsity := []int{0, 1, 2}
	for _, m := range dims {
		for _, n := range dims {
			for _, incv := range strides {
				for _, ldc := range []int{max(1, n), n + 3} {
					for _, nnzv := range sparsity {
						for _, nnzc := range sparsity {
							// The random tau is drawn before either call; the
							// zero-tau case then runs first, followed by the
							// random-tau case.
							randTau := rnd.NormFloat64()
							dlarfTest(t, impl, rnd, side, m, n, incv, ldc, nnzv, nnzc, 0)
							dlarfTest(t, impl, rnd, side, m, n, incv, ldc, nnzv, nnzc, randTau)
						}
					}
				}
			}
		}
	}
}
 | |
| 
 | |
// dlarfTest runs a single Dlarf case and reports any discrepancy through t.
//
// It builds a random m×n matrix C, stored row-major with leading dimension
// ldc. It also builds a random vector v with stride incv, whose length is m
// for blas.Left and n for blas.Right. nnzc and nnzv control how much of C
// and v is zeroed before the call:
//
//	0:         all of C (or v) is zeroed,
//	1:         the right half of C (blas.Left) or the bottom half of C
//	           (blas.Right), and the trailing half of v, are zeroed,
//	otherwise: the random content is kept.
//
// After calling impl.Dlarf it checks that:
//   - v is unmodified;
//   - C is entirely unchanged when tau is zero;
//   - the ldc padding of C outside the m×n submatrix is never written;
//   - the m×n result matches (I - tau*v*vᵀ)*C for blas.Left, or
//     C*(I - tau*v*vᵀ) for blas.Right, within a max-column-sum tolerance
//     scaled by max(m, n).
func dlarfTest(t *testing.T, impl Dlarfer, rnd *rand.Rand, side blas.Side, m, n, incv, ldc, nnzv, nnzc int, tau float64) {
	const tol = 1e-14

	c := make([]float64, m*ldc)
	for i := range c {
		c[i] = rnd.NormFloat64()
	}
	switch nnzc {
	case 0:
		// Zero out all of C.
		for i := 0; i < m; i++ {
			for j := 0; j < n; j++ {
				c[i*ldc+j] = 0
			}
		}
	case 1:
		// Zero out the right half of C for blas.Left, or the bottom half of
		// C for blas.Right.
		if side == blas.Left {
			for i := 0; i < m; i++ {
				for j := n / 2; j < n; j++ {
					c[i*ldc+j] = 0
				}
			}
		} else {
			for i := m / 2; i < m; i++ {
				for j := 0; j < n; j++ {
					c[i*ldc+j] = 0
				}
			}
		}
	default:
		// Leave C with random content.
	}
	cCopy := make([]float64, len(c))
	copy(cCopy, c)

	// Workspace: n elements for blas.Left, m elements for blas.Right.
	var work []float64
	if side == blas.Left {
		work = make([]float64, n)
	} else {
		work = make([]float64, m)
	}

	// v has one logical element per row of C for blas.Left and one per
	// column for blas.Right. At least one element is allocated so that the
	// stride arithmetic stays valid when that dimension is zero.
	vlen := n
	if side == blas.Left {
		vlen = m
	}
	vlen = max(1, vlen)
	v := make([]float64, 1+(vlen-1)*incv)
	for i := range v {
		v[i] = rnd.NormFloat64()
	}
	switch nnzv {
	case 0:
		// Zero out all of v.
		for i := 0; i < vlen; i++ {
			v[i*incv] = 0
		}
	case 1:
		// Zero out the trailing half of v.
		for i := vlen / 2; i < vlen; i++ {
			v[i*incv] = 0
		}
	default:
		// Leave v with random content.
	}
	vCopy := make([]float64, len(v))
	copy(vCopy, v)

	impl.Dlarf(side, m, n, v, incv, tau, c, ldc, work)
	got := c

	// Include the sparsity parameters: they vary across the sweep, and
	// without them failures for different patterns are indistinguishable.
	name := fmt.Sprintf("m=%d,n=%d,incv=%d,tau=%f,ldc=%d,nnzv=%d,nnzc=%d", m, n, incv, tau, ldc, nnzv, nnzc)

	if !floats.Equal(v, vCopy) {
		t.Errorf("%v: unexpected modification of v", name)
	}
	if tau == 0 {
		// H is the identity, so all of C, including the padding, must be
		// unchanged.
		if !floats.Equal(got, cCopy) {
			t.Errorf("%v: unexpected modification of C", name)
		}
		return
	}

	// Columns n..ldc-1 of each row lie outside the m×n matrix. The
	// residual below only covers the m×n submatrix, so check the padding
	// explicitly.
checkPadding:
	for i := 0; i < m; i++ {
		for j := n; j < ldc; j++ {
			if got[i*ldc+j] != cCopy[i*ldc+j] {
				t.Errorf("%v: unexpected modification of C outside the m×n submatrix at (%d,%d)", name, i, j)
				break checkPadding
			}
		}
	}

	if m == 0 || n == 0 {
		return
	}

	bi := blas64.Implementation()

	want := make([]float64, len(cCopy))
	if side == blas.Left {
		// Compute want = (I - tau * v * vᵀ) * C.

		// vtc = -tau * vᵀ * C = -tau * Cᵀ * v
		vtc := make([]float64, n)
		bi.Dgemv(blas.Trans, m, n, -tau, cCopy, ldc, v, incv, 0, vtc, 1)

		// want = C + v * vtcᵀ
		for i := 0; i < m; i++ {
			for j := 0; j < n; j++ {
				want[i*ldc+j] = cCopy[i*ldc+j] + v[i*incv]*vtc[j]
			}
		}
	} else {
		// Compute want = C * (I - tau * v * vᵀ).

		// cv = -tau * C * v
		cv := make([]float64, m)
		bi.Dgemv(blas.NoTrans, m, n, -tau, cCopy, ldc, v, incv, 0, cv, 1)

		// want = C + cv * vᵀ
		for i := 0; i < m; i++ {
			for j := 0; j < n; j++ {
				want[i*ldc+j] = cCopy[i*ldc+j] + cv[i]*v[j*incv]
			}
		}
	}

	// Pack the m×n difference densely (leading dimension n) so dlange sees
	// only the logical matrix and none of the padding.
	diff := make([]float64, m*n)
	for i := 0; i < m; i++ {
		for j := 0; j < n; j++ {
			diff[i*n+j] = got[i*ldc+j] - want[i*ldc+j]
		}
	}
	resid := dlange(lapack.MaxColumnSum, m, n, diff, n)
	if resid > tol*float64(max(m, n)) {
		t.Errorf("%v: unexpected result; resid=%v, want<=%v", name, resid, tol*float64(max(m, n)))
	}
}
 | 
