mirror of
https://github.com/gonum/gonum.git
synced 2025-10-09 00:50:16 +08:00

This commit merges the three packages matrix, mat64, and cmat128, then renames the merged package to mat. It also fixes the import statements and the corresponding code.
118 lines
3.0 KiB
Go
118 lines
3.0 KiB
Go
// Copyright ©2017 The gonum Authors. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE file.
|
|
|
|
package mat
|
|
|
|
import (
|
|
"math/rand"
|
|
"testing"
|
|
|
|
"gonum.org/v1/gonum/floats"
|
|
)
|
|
|
|
func TestGSVD(t *testing.T) {
|
|
const tol = 1e-10
|
|
rnd := rand.New(rand.NewSource(1))
|
|
for _, test := range []struct {
|
|
m, p, n int
|
|
}{
|
|
{5, 3, 5},
|
|
{5, 3, 3},
|
|
{3, 3, 5},
|
|
{5, 5, 5},
|
|
{5, 5, 3},
|
|
{3, 5, 5},
|
|
{150, 150, 150},
|
|
{200, 150, 150},
|
|
{150, 150, 200},
|
|
{150, 200, 150},
|
|
{200, 200, 150},
|
|
{150, 200, 200},
|
|
} {
|
|
m := test.m
|
|
p := test.p
|
|
n := test.n
|
|
for trial := 0; trial < 10; trial++ {
|
|
a := NewDense(m, n, nil)
|
|
for i := range a.mat.Data {
|
|
a.mat.Data[i] = rnd.NormFloat64()
|
|
}
|
|
aCopy := DenseCopyOf(a)
|
|
|
|
b := NewDense(p, n, nil)
|
|
for i := range b.mat.Data {
|
|
b.mat.Data[i] = rnd.NormFloat64()
|
|
}
|
|
bCopy := DenseCopyOf(b)
|
|
|
|
// Test Full decomposition.
|
|
var gsvd GSVD
|
|
ok := gsvd.Factorize(a, b, GSVDU|GSVDV|GSVDQ)
|
|
if !ok {
|
|
t.Errorf("GSVD factorization failed")
|
|
}
|
|
if !Equal(a, aCopy) {
|
|
t.Errorf("A changed during call to GSVD.Factorize with GSVDU|GSVDV|GSVDQ")
|
|
}
|
|
if !Equal(b, bCopy) {
|
|
t.Errorf("B changed during call to GSVD.Factorize with GSVDU|GSVDV|GSVDQ")
|
|
}
|
|
c, s, sigma1, sigma2, zeroR, u, v, q := extractGSVD(&gsvd)
|
|
var ansU, ansV, d1R, d2R Dense
|
|
ansU.Product(u.T(), a, q)
|
|
ansV.Product(v.T(), b, q)
|
|
d1R.Mul(sigma1, zeroR)
|
|
d2R.Mul(sigma2, zeroR)
|
|
if !EqualApprox(&ansU, &d1R, tol) {
|
|
t.Errorf("Answer mismatch with GSVDU|GSVDV|GSVDQ\nU^T * A * Q:\n% 0.2f\nΣ₁ * [ 0 R ]:\n% 0.2f",
|
|
Formatted(&ansU), Formatted(&d1R))
|
|
}
|
|
if !EqualApprox(&ansV, &d2R, tol) {
|
|
t.Errorf("Answer mismatch with GSVDU|GSVDV|GSVDQ\nV^T * B *Q:\n% 0.2f\nΣ₂ * [ 0 R ]:\n% 0.2f",
|
|
Formatted(&d2R), Formatted(&ansV))
|
|
}
|
|
|
|
// Check C^2 + S^2 = I.
|
|
for i := range c {
|
|
d := c[i]*c[i] + s[i]*s[i]
|
|
if !floats.EqualWithinAbsOrRel(d, 1, 1e-14, 1e-14) {
|
|
t.Errorf("c_%d^2 + s_%d^2 != 1: got: %v", i, i, d)
|
|
}
|
|
}
|
|
|
|
// Test None decomposition.
|
|
ok = gsvd.Factorize(a, b, GSVDNone)
|
|
if !ok {
|
|
t.Errorf("GSVD factorization failed")
|
|
}
|
|
if !Equal(a, aCopy) {
|
|
t.Errorf("A changed during call to GSVD with GSVDNone")
|
|
}
|
|
if !Equal(b, bCopy) {
|
|
t.Errorf("B changed during call to GSVD with GSVDNone")
|
|
}
|
|
cNone := gsvd.ValuesA(nil)
|
|
if !floats.EqualApprox(c, cNone, tol) {
|
|
t.Errorf("Singular value mismatch between GSVDU|GSVDV|GSVDQ and GSVDNone decomposition")
|
|
}
|
|
sNone := gsvd.ValuesB(nil)
|
|
if !floats.EqualApprox(s, sNone, tol) {
|
|
t.Errorf("Singular value mismatch between GSVDU|GSVDV|GSVDQ and GSVDNone decomposition")
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func extractGSVD(gsvd *GSVD) (c, s []float64, s1, s2, zR, u, v, q *Dense) {
|
|
s1 = gsvd.SigmaATo(nil)
|
|
s2 = gsvd.SigmaBTo(nil)
|
|
zR = gsvd.ZeroRTo(nil)
|
|
u = gsvd.UTo(nil)
|
|
v = gsvd.VTo(nil)
|
|
q = gsvd.QTo(nil)
|
|
c = gsvd.ValuesA(nil)
|
|
s = gsvd.ValuesB(nil)
|
|
return c, s, s1, s2, zR, u, v, q
|
|
}
|