mirror of
				https://github.com/gonum/gonum.git
				synced 2025-10-31 10:36:30 +08:00 
			
		
		
		
	 17ea55aedb
			
		
	
	17ea55aedb
	
	
	
		
			
			Apply (with manual curation after the fact):
* s/^T/U+1d40/g
* s/^H/U+1d34/g
* s/, {2,3}if / $1/g
Some additional manual editing of odd formatting.
		
	
		
			
				
	
	
		
			184 lines
		
	
	
		
			4.2 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			184 lines
		
	
	
		
			4.2 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| // Copyright ©2015 The Gonum Authors. All rights reserved.
 | |
| // Use of this source code is governed by a BSD-style
 | |
| // license that can be found in the LICENSE file.
 | |
| 
 | |
| package testlapack
 | |
| 
 | |
| import (
 | |
| 	"testing"
 | |
| 
 | |
| 	"golang.org/x/exp/rand"
 | |
| 
 | |
| 	"gonum.org/v1/gonum/blas"
 | |
| 	"gonum.org/v1/gonum/blas/blas64"
 | |
| 	"gonum.org/v1/gonum/floats"
 | |
| )
 | |
| 
 | |
| type Dgelser interface {
 | |
| 	Dgels(trans blas.Transpose, m, n, nrhs int, a []float64, lda int, b []float64, ldb int, work []float64, lwork int) bool
 | |
| }
 | |
| 
 | |
| func DgelsTest(t *testing.T, impl Dgelser) {
 | |
| 	rnd := rand.New(rand.NewSource(1))
 | |
| 	for _, trans := range []blas.Transpose{blas.NoTrans, blas.Trans} {
 | |
| 		for _, test := range []struct {
 | |
| 			m, n, nrhs, lda, ldb int
 | |
| 		}{
 | |
| 			{3, 4, 5, 0, 0},
 | |
| 			{3, 5, 4, 0, 0},
 | |
| 			{4, 3, 5, 0, 0},
 | |
| 			{4, 5, 3, 0, 0},
 | |
| 			{5, 3, 4, 0, 0},
 | |
| 			{5, 4, 3, 0, 0},
 | |
| 			{3, 4, 5, 10, 20},
 | |
| 			{3, 5, 4, 10, 20},
 | |
| 			{4, 3, 5, 10, 20},
 | |
| 			{4, 5, 3, 10, 20},
 | |
| 			{5, 3, 4, 10, 20},
 | |
| 			{5, 4, 3, 10, 20},
 | |
| 			{3, 4, 5, 20, 10},
 | |
| 			{3, 5, 4, 20, 10},
 | |
| 			{4, 3, 5, 20, 10},
 | |
| 			{4, 5, 3, 20, 10},
 | |
| 			{5, 3, 4, 20, 10},
 | |
| 			{5, 4, 3, 20, 10},
 | |
| 			{200, 300, 400, 0, 0},
 | |
| 			{200, 400, 300, 0, 0},
 | |
| 			{300, 200, 400, 0, 0},
 | |
| 			{300, 400, 200, 0, 0},
 | |
| 			{400, 200, 300, 0, 0},
 | |
| 			{400, 300, 200, 0, 0},
 | |
| 			{200, 300, 400, 500, 600},
 | |
| 			{200, 400, 300, 500, 600},
 | |
| 			{300, 200, 400, 500, 600},
 | |
| 			{300, 400, 200, 500, 600},
 | |
| 			{400, 200, 300, 500, 600},
 | |
| 			{400, 300, 200, 500, 600},
 | |
| 			{200, 300, 400, 600, 500},
 | |
| 			{200, 400, 300, 600, 500},
 | |
| 			{300, 200, 400, 600, 500},
 | |
| 			{300, 400, 200, 600, 500},
 | |
| 			{400, 200, 300, 600, 500},
 | |
| 			{400, 300, 200, 600, 500},
 | |
| 		} {
 | |
| 			m := test.m
 | |
| 			n := test.n
 | |
| 			nrhs := test.nrhs
 | |
| 
 | |
| 			lda := test.lda
 | |
| 			if lda == 0 {
 | |
| 				lda = n
 | |
| 			}
 | |
| 			a := make([]float64, m*lda)
 | |
| 			for i := range a {
 | |
| 				a[i] = rnd.Float64()
 | |
| 			}
 | |
| 			aCopy := make([]float64, len(a))
 | |
| 			copy(aCopy, a)
 | |
| 
 | |
| 			// Size of b is the same trans or no trans, because the number of rows
 | |
| 			// has to be the max of (m,n).
 | |
| 			mb := max(m, n)
 | |
| 			nb := nrhs
 | |
| 			ldb := test.ldb
 | |
| 			if ldb == 0 {
 | |
| 				ldb = nb
 | |
| 			}
 | |
| 			b := make([]float64, mb*ldb)
 | |
| 			for i := range b {
 | |
| 				b[i] = rnd.Float64()
 | |
| 			}
 | |
| 			bCopy := make([]float64, len(b))
 | |
| 			copy(bCopy, b)
 | |
| 
 | |
| 			// Find optimal work length.
 | |
| 			work := make([]float64, 1)
 | |
| 			impl.Dgels(trans, m, n, nrhs, a, lda, b, ldb, work, -1)
 | |
| 
 | |
| 			// Perform linear solve
 | |
| 			work = make([]float64, int(work[0]))
 | |
| 			lwork := len(work)
 | |
| 			for i := range work {
 | |
| 				work[i] = rnd.Float64()
 | |
| 			}
 | |
| 			impl.Dgels(trans, m, n, nrhs, a, lda, b, ldb, work, lwork)
 | |
| 
 | |
| 			// Check that the answer is correct by comparing to the normal equations.
 | |
| 			aMat := blas64.General{
 | |
| 				Rows:   m,
 | |
| 				Cols:   n,
 | |
| 				Stride: lda,
 | |
| 				Data:   make([]float64, len(aCopy)),
 | |
| 			}
 | |
| 			copy(aMat.Data, aCopy)
 | |
| 			szAta := n
 | |
| 			if trans == blas.Trans {
 | |
| 				szAta = m
 | |
| 			}
 | |
| 			aTA := blas64.General{
 | |
| 				Rows:   szAta,
 | |
| 				Cols:   szAta,
 | |
| 				Stride: szAta,
 | |
| 				Data:   make([]float64, szAta*szAta),
 | |
| 			}
 | |
| 
 | |
| 			// Compute Aᵀ * A if notrans and A * Aᵀ otherwise.
 | |
| 			if trans == blas.NoTrans {
 | |
| 				blas64.Gemm(blas.Trans, blas.NoTrans, 1, aMat, aMat, 0, aTA)
 | |
| 			} else {
 | |
| 				blas64.Gemm(blas.NoTrans, blas.Trans, 1, aMat, aMat, 0, aTA)
 | |
| 			}
 | |
| 
 | |
| 			// Multiply by X.
 | |
| 			X := blas64.General{
 | |
| 				Rows:   szAta,
 | |
| 				Cols:   nrhs,
 | |
| 				Stride: ldb,
 | |
| 				Data:   b,
 | |
| 			}
 | |
| 			ans := blas64.General{
 | |
| 				Rows:   aTA.Rows,
 | |
| 				Cols:   X.Cols,
 | |
| 				Stride: X.Cols,
 | |
| 				Data:   make([]float64, aTA.Rows*X.Cols),
 | |
| 			}
 | |
| 			blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, aTA, X, 0, ans)
 | |
| 
 | |
| 			B := blas64.General{
 | |
| 				Rows:   szAta,
 | |
| 				Cols:   nrhs,
 | |
| 				Stride: ldb,
 | |
| 				Data:   make([]float64, len(bCopy)),
 | |
| 			}
 | |
| 
 | |
| 			copy(B.Data, bCopy)
 | |
| 			var ans2 blas64.General
 | |
| 			if trans == blas.NoTrans {
 | |
| 				ans2 = blas64.General{
 | |
| 					Rows:   aMat.Cols,
 | |
| 					Cols:   B.Cols,
 | |
| 					Stride: B.Cols,
 | |
| 					Data:   make([]float64, aMat.Cols*B.Cols),
 | |
| 				}
 | |
| 			} else {
 | |
| 				ans2 = blas64.General{
 | |
| 					Rows:   aMat.Rows,
 | |
| 					Cols:   B.Cols,
 | |
| 					Stride: B.Cols,
 | |
| 					Data:   make([]float64, aMat.Rows*B.Cols),
 | |
| 				}
 | |
| 			}
 | |
| 
 | |
| 			// Compute Aᵀ B if Trans or A * B otherwise
 | |
| 			if trans == blas.NoTrans {
 | |
| 				blas64.Gemm(blas.Trans, blas.NoTrans, 1, aMat, B, 0, ans2)
 | |
| 			} else {
 | |
| 				blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, aMat, B, 0, ans2)
 | |
| 			}
 | |
| 			if !floats.EqualApprox(ans.Data, ans2.Data, 1e-12) {
 | |
| 				t.Errorf("Normal equations not satisfied")
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| }
 |