mirror of
https://github.com/gonum/gonum.git
synced 2025-10-04 14:52:57 +08:00
272 lines
7.0 KiB
Go
272 lines
7.0 KiB
Go
// Copyright ©2015 The Gonum Authors. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
package testlapack
|
|
|
|
import (
|
|
"fmt"
|
|
"testing"
|
|
|
|
"golang.org/x/exp/rand"
|
|
|
|
"gonum.org/v1/gonum/blas"
|
|
"gonum.org/v1/gonum/blas/blas64"
|
|
"gonum.org/v1/gonum/floats"
|
|
"gonum.org/v1/gonum/lapack"
|
|
)
|
|
|
|
type Dgesvder interface {
|
|
Dgesvd(jobU, jobVT lapack.SVDJob, m, n int, a []float64, lda int, s, u []float64, ldu int, vt []float64, ldvt int, work []float64, lwork int) (ok bool)
|
|
}
|
|
|
|
func DgesvdTest(t *testing.T, impl Dgesvder) {
|
|
rnd := rand.New(rand.NewSource(1))
|
|
// TODO(btracey): Add tests for all of the cases when the SVD implementation
|
|
// is finished.
|
|
// TODO(btracey): Add tests for m > mnthr and n > mnthr when other SVD
|
|
// conditions are implemented. Right now mnthr is 5,000,000 which is too
|
|
// large to create a square matrix of that size.
|
|
for _, test := range []struct {
|
|
m, n, lda, ldu, ldvt int
|
|
}{
|
|
{5, 5, 0, 0, 0},
|
|
{5, 6, 0, 0, 0},
|
|
{6, 5, 0, 0, 0},
|
|
{5, 9, 0, 0, 0},
|
|
{9, 5, 0, 0, 0},
|
|
|
|
{5, 5, 10, 11, 12},
|
|
{5, 6, 10, 11, 12},
|
|
{6, 5, 10, 11, 12},
|
|
{5, 5, 10, 11, 12},
|
|
{5, 9, 10, 11, 12},
|
|
{9, 5, 10, 11, 12},
|
|
|
|
{300, 300, 0, 0, 0},
|
|
{300, 400, 0, 0, 0},
|
|
{400, 300, 0, 0, 0},
|
|
{300, 600, 0, 0, 0},
|
|
{600, 300, 0, 0, 0},
|
|
|
|
{300, 300, 400, 450, 460},
|
|
{300, 400, 500, 550, 560},
|
|
{400, 300, 550, 550, 560},
|
|
{300, 600, 700, 750, 760},
|
|
{600, 300, 700, 750, 760},
|
|
} {
|
|
jobU := lapack.SVDAll
|
|
jobVT := lapack.SVDAll
|
|
|
|
m := test.m
|
|
n := test.n
|
|
lda := test.lda
|
|
if lda == 0 {
|
|
lda = n
|
|
}
|
|
ldu := test.ldu
|
|
if ldu == 0 {
|
|
ldu = m
|
|
}
|
|
ldvt := test.ldvt
|
|
if ldvt == 0 {
|
|
ldvt = n
|
|
}
|
|
|
|
a := make([]float64, m*lda)
|
|
for i := range a {
|
|
a[i] = rnd.NormFloat64()
|
|
}
|
|
|
|
u := make([]float64, m*ldu)
|
|
for i := range u {
|
|
u[i] = rnd.NormFloat64()
|
|
}
|
|
|
|
vt := make([]float64, n*ldvt)
|
|
for i := range vt {
|
|
vt[i] = rnd.NormFloat64()
|
|
}
|
|
|
|
uAllOrig := make([]float64, len(u))
|
|
copy(uAllOrig, u)
|
|
vtAllOrig := make([]float64, len(vt))
|
|
copy(vtAllOrig, vt)
|
|
aCopy := make([]float64, len(a))
|
|
copy(aCopy, a)
|
|
|
|
s := make([]float64, min(m, n))
|
|
|
|
work := make([]float64, 1)
|
|
impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, u, ldu, vt, ldvt, work, -1)
|
|
|
|
if !floats.Equal(a, aCopy) {
|
|
t.Errorf("a changed during call to get work length")
|
|
}
|
|
|
|
work = make([]float64, int(work[0]))
|
|
impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, u, ldu, vt, ldvt, work, len(work))
|
|
|
|
errStr := fmt.Sprintf("m = %v, n = %v, lda = %v, ldu = %v, ldv = %v", m, n, lda, ldu, ldvt)
|
|
svdCheck(t, false, errStr, m, n, s, a, u, ldu, vt, ldvt, aCopy, lda)
|
|
svdCheckPartial(t, impl, lapack.SVDAll, errStr, uAllOrig, vtAllOrig, aCopy, m, n, a, lda, s, u, ldu, vt, ldvt, work, false)
|
|
|
|
// Test InPlace
|
|
jobU = lapack.SVDInPlace
|
|
jobVT = lapack.SVDInPlace
|
|
copy(a, aCopy)
|
|
copy(u, uAllOrig)
|
|
copy(vt, vtAllOrig)
|
|
|
|
impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, u, ldu, vt, ldvt, work, len(work))
|
|
svdCheck(t, true, errStr, m, n, s, a, u, ldu, vt, ldvt, aCopy, lda)
|
|
svdCheckPartial(t, impl, lapack.SVDInPlace, errStr, uAllOrig, vtAllOrig, aCopy, m, n, a, lda, s, u, ldu, vt, ldvt, work, false)
|
|
}
|
|
}
|
|
|
|
// svdCheckPartial checks that the singular values and vectors are computed when
|
|
// not all of them are computed.
|
|
func svdCheckPartial(t *testing.T, impl Dgesvder, job lapack.SVDJob, errStr string, uAllOrig, vtAllOrig, aCopy []float64, m, n int, a []float64, lda int, s, u []float64, ldu int, vt []float64, ldvt int, work []float64, shortWork bool) {
|
|
rnd := rand.New(rand.NewSource(1))
|
|
jobU := job
|
|
jobVT := job
|
|
// Compare the singular values when computed with {SVDNone, SVDNone.}
|
|
sCopy := make([]float64, len(s))
|
|
copy(sCopy, s)
|
|
copy(a, aCopy)
|
|
for i := range s {
|
|
s[i] = rnd.Float64()
|
|
}
|
|
tmp1 := make([]float64, 1)
|
|
tmp2 := make([]float64, 1)
|
|
jobU = lapack.SVDNone
|
|
jobVT = lapack.SVDNone
|
|
|
|
impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, tmp1, ldu, tmp2, ldvt, work, -1)
|
|
work = make([]float64, int(work[0]))
|
|
lwork := len(work)
|
|
if shortWork {
|
|
lwork--
|
|
}
|
|
ok := impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, tmp1, ldu, tmp2, ldvt, work, lwork)
|
|
if !ok {
|
|
t.Errorf("Dgesvd did not complete successfully")
|
|
}
|
|
if !floats.EqualApprox(s, sCopy, 1e-10) {
|
|
t.Errorf("Singular value mismatch when singular vectors not computed: %s", errStr)
|
|
}
|
|
// Check that the singular vectors are correctly computed when the other
|
|
// is none.
|
|
uAll := make([]float64, len(u))
|
|
copy(uAll, u)
|
|
vtAll := make([]float64, len(vt))
|
|
copy(vtAll, vt)
|
|
|
|
// Copy the original vectors so the data outside the matrix bounds is the same.
|
|
copy(u, uAllOrig)
|
|
copy(vt, vtAllOrig)
|
|
|
|
jobU = job
|
|
jobVT = lapack.SVDNone
|
|
copy(a, aCopy)
|
|
for i := range s {
|
|
s[i] = rnd.Float64()
|
|
}
|
|
impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, u, ldu, tmp2, ldvt, work, -1)
|
|
work = make([]float64, int(work[0]))
|
|
lwork = len(work)
|
|
if shortWork {
|
|
lwork--
|
|
}
|
|
impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, u, ldu, tmp2, ldvt, work, len(work))
|
|
if !floats.EqualApprox(uAll, u, 1e-10) {
|
|
t.Errorf("U mismatch when VT is not computed: %s", errStr)
|
|
}
|
|
if !floats.EqualApprox(s, sCopy, 1e-10) {
|
|
t.Errorf("Singular value mismatch when U computed VT not")
|
|
}
|
|
jobU = lapack.SVDNone
|
|
jobVT = job
|
|
copy(a, aCopy)
|
|
for i := range s {
|
|
s[i] = rnd.Float64()
|
|
}
|
|
impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, tmp1, ldu, vt, ldvt, work, -1)
|
|
work = make([]float64, int(work[0]))
|
|
lwork = len(work)
|
|
if shortWork {
|
|
lwork--
|
|
}
|
|
impl.Dgesvd(jobU, jobVT, m, n, a, lda, s, tmp1, ldu, vt, ldvt, work, len(work))
|
|
if !floats.EqualApprox(vtAll, vt, 1e-10) {
|
|
t.Errorf("VT mismatch when U is not computed: %s", errStr)
|
|
}
|
|
if !floats.EqualApprox(s, sCopy, 1e-10) {
|
|
t.Errorf("Singular value mismatch when VT computed U not")
|
|
}
|
|
}
|
|
|
|
// svdCheck checks that the singular value decomposition correctly multiplies back
|
|
// to the original matrix.
|
|
func svdCheck(t *testing.T, thin bool, errStr string, m, n int, s, a, u []float64, ldu int, vt []float64, ldvt int, aCopy []float64, lda int) {
|
|
sigma := blas64.General{
|
|
Rows: m,
|
|
Cols: n,
|
|
Stride: n,
|
|
Data: make([]float64, m*n),
|
|
}
|
|
for i := 0; i < min(m, n); i++ {
|
|
sigma.Data[i*sigma.Stride+i] = s[i]
|
|
}
|
|
|
|
uMat := blas64.General{
|
|
Rows: m,
|
|
Cols: m,
|
|
Stride: ldu,
|
|
Data: u,
|
|
}
|
|
vTMat := blas64.General{
|
|
Rows: n,
|
|
Cols: n,
|
|
Stride: ldvt,
|
|
Data: vt,
|
|
}
|
|
if thin {
|
|
sigma.Rows = min(m, n)
|
|
sigma.Cols = min(m, n)
|
|
uMat.Cols = min(m, n)
|
|
vTMat.Rows = min(m, n)
|
|
}
|
|
|
|
tmp := blas64.General{
|
|
Rows: m,
|
|
Cols: n,
|
|
Stride: n,
|
|
Data: make([]float64, m*n),
|
|
}
|
|
ans := blas64.General{
|
|
Rows: m,
|
|
Cols: n,
|
|
Stride: lda,
|
|
Data: make([]float64, m*lda),
|
|
}
|
|
copy(ans.Data, a)
|
|
|
|
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, uMat, sigma, 0, tmp)
|
|
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, tmp, vTMat, 0, ans)
|
|
|
|
if !floats.EqualApprox(ans.Data, aCopy, 1e-8) {
|
|
t.Errorf("Decomposition mismatch. Trim = %v, %s", thin, errStr)
|
|
}
|
|
|
|
if !thin {
|
|
// Check that U and V are orthogonal.
|
|
if !isOrthogonal(uMat) {
|
|
t.Errorf("U not orthogonal %s", errStr)
|
|
}
|
|
if !isOrthogonal(vTMat) {
|
|
t.Errorf("V not orthogonal %s", errStr)
|
|
}
|
|
}
|
|
}
|