mirror of
https://github.com/gonum/gonum.git
synced 2025-10-07 08:01:20 +08:00
blas: add hermitian conversions
This commit is contained in:
@@ -108,108 +108,6 @@ func TestConvertGeneral(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func newSymmetricFrom(a SymmetricCols) Symmetric {
|
||||
t := Symmetric{
|
||||
N: a.N,
|
||||
Stride: a.N,
|
||||
Data: make([]float32, a.N*a.N),
|
||||
Uplo: a.Uplo,
|
||||
}
|
||||
t.From(a)
|
||||
return t
|
||||
}
|
||||
|
||||
func (m Symmetric) n() int { return m.N }
|
||||
func (m Symmetric) at(i, j int) float32 {
|
||||
if m.Uplo == blas.Lower && i < j && j < m.N {
|
||||
i, j = j, i
|
||||
}
|
||||
if m.Uplo == blas.Upper && i > j {
|
||||
i, j = j, i
|
||||
}
|
||||
return m.Data[i*m.Stride+j]
|
||||
}
|
||||
func (m Symmetric) uplo() blas.Uplo { return m.Uplo }
|
||||
|
||||
func newSymmetricColsFrom(a Symmetric) SymmetricCols {
|
||||
t := SymmetricCols{
|
||||
N: a.N,
|
||||
Stride: a.N,
|
||||
Data: make([]float32, a.N*a.N),
|
||||
Uplo: a.Uplo,
|
||||
}
|
||||
t.From(a)
|
||||
return t
|
||||
}
|
||||
|
||||
func (m SymmetricCols) n() int { return m.N }
|
||||
func (m SymmetricCols) at(i, j int) float32 {
|
||||
if m.Uplo == blas.Lower && i < j {
|
||||
i, j = j, i
|
||||
}
|
||||
if m.Uplo == blas.Upper && i > j && i < m.N {
|
||||
i, j = j, i
|
||||
}
|
||||
return m.Data[i+j*m.Stride]
|
||||
}
|
||||
func (m SymmetricCols) uplo() blas.Uplo { return m.Uplo }
|
||||
|
||||
type symmetric interface {
|
||||
n() int
|
||||
at(i, j int) float32
|
||||
uplo() blas.Uplo
|
||||
}
|
||||
|
||||
func equalSymmetric(a, b symmetric) bool {
|
||||
an := a.n()
|
||||
bn := b.n()
|
||||
if an != bn {
|
||||
return false
|
||||
}
|
||||
if a.uplo() != b.uplo() {
|
||||
return false
|
||||
}
|
||||
for i := 0; i < an; i++ {
|
||||
for j := 0; j < an; j++ {
|
||||
if a.at(i, j) != b.at(i, j) || math.IsNaN(a.at(i, j)) != math.IsNaN(b.at(i, j)) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
var symmetricTests = []Symmetric{
|
||||
{N: 3, Stride: 3, Data: []float32{
|
||||
1, 2, 3,
|
||||
4, 5, 6,
|
||||
7, 8, 9,
|
||||
}},
|
||||
{N: 3, Stride: 5, Data: []float32{
|
||||
1, 2, 3, 0, 0,
|
||||
4, 5, 6, 0, 0,
|
||||
7, 8, 9, 0, 0,
|
||||
}},
|
||||
}
|
||||
|
||||
func TestConvertSymmetric(t *testing.T) {
|
||||
for _, test := range symmetricTests {
|
||||
for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} {
|
||||
test.Uplo = uplo
|
||||
colmajor := newSymmetricColsFrom(test)
|
||||
if !equalSymmetric(colmajor, test) {
|
||||
t.Errorf("unexpected result for row major to col major conversion:\n\tgot: %#v\n\tfrom:%#v",
|
||||
colmajor, test)
|
||||
}
|
||||
rowmajor := newSymmetricFrom(colmajor)
|
||||
if !equalSymmetric(rowmajor, test) {
|
||||
t.Errorf("unexpected result for col major to row major conversion:\n\tgot: %#v\n\twant:%#v",
|
||||
rowmajor, test)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func newTriangularFrom(a TriangularCols) Triangular {
|
||||
t := Triangular{
|
||||
N: a.N,
|
||||
@@ -515,198 +413,6 @@ func TestConvertBand(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func newSymmetricBandFrom(a SymmetricBandCols) SymmetricBand {
|
||||
t := SymmetricBand{
|
||||
N: a.N,
|
||||
K: a.K,
|
||||
Stride: a.K + 1,
|
||||
Data: make([]float32, a.N*(a.K+1)),
|
||||
Uplo: a.Uplo,
|
||||
}
|
||||
for i := range t.Data {
|
||||
t.Data[i] = math.NaN()
|
||||
}
|
||||
t.From(a)
|
||||
return t
|
||||
}
|
||||
|
||||
func (m SymmetricBand) n() (n int) { return m.N }
|
||||
func (m SymmetricBand) at(i, j int) float32 {
|
||||
b := Band{
|
||||
Rows: m.N, Cols: m.N,
|
||||
Stride: m.Stride,
|
||||
Data: m.Data,
|
||||
}
|
||||
switch m.Uplo {
|
||||
default:
|
||||
panic("blas32: bad BLAS uplo")
|
||||
case blas.Upper:
|
||||
b.KU = m.K
|
||||
if i > j {
|
||||
i, j = j, i
|
||||
}
|
||||
case blas.Lower:
|
||||
b.KL = m.K
|
||||
if i < j {
|
||||
i, j = j, i
|
||||
}
|
||||
}
|
||||
return b.at(i, j)
|
||||
}
|
||||
func (m SymmetricBand) bandwidth() (k int) { return m.K }
|
||||
func (m SymmetricBand) uplo() blas.Uplo { return m.Uplo }
|
||||
|
||||
func newSymmetricBandColsFrom(a SymmetricBand) SymmetricBandCols {
|
||||
t := SymmetricBandCols{
|
||||
N: a.N,
|
||||
K: a.K,
|
||||
Stride: a.K + 1,
|
||||
Data: make([]float32, a.N*(a.K+1)),
|
||||
Uplo: a.Uplo,
|
||||
}
|
||||
for i := range t.Data {
|
||||
t.Data[i] = math.NaN()
|
||||
}
|
||||
t.From(a)
|
||||
return t
|
||||
}
|
||||
|
||||
func (m SymmetricBandCols) n() (n int) { return m.N }
|
||||
func (m SymmetricBandCols) at(i, j int) float32 {
|
||||
b := BandCols{
|
||||
Rows: m.N, Cols: m.N,
|
||||
Stride: m.Stride,
|
||||
Data: m.Data,
|
||||
}
|
||||
switch m.Uplo {
|
||||
default:
|
||||
panic("blas32: bad BLAS uplo")
|
||||
case blas.Upper:
|
||||
b.KU = m.K
|
||||
if i > j {
|
||||
i, j = j, i
|
||||
}
|
||||
case blas.Lower:
|
||||
b.KL = m.K
|
||||
if i < j {
|
||||
i, j = j, i
|
||||
}
|
||||
}
|
||||
return b.at(i, j)
|
||||
}
|
||||
func (m SymmetricBandCols) bandwidth() (k int) { return m.K }
|
||||
func (m SymmetricBandCols) uplo() blas.Uplo { return m.Uplo }
|
||||
|
||||
type symmetricBand interface {
|
||||
n() (n int)
|
||||
at(i, j int) float32
|
||||
bandwidth() (k int)
|
||||
uplo() blas.Uplo
|
||||
}
|
||||
|
||||
func equalSymmetricBand(a, b symmetricBand) bool {
|
||||
an := a.n()
|
||||
bn := b.n()
|
||||
if an != bn {
|
||||
return false
|
||||
}
|
||||
if a.uplo() != b.uplo() {
|
||||
return false
|
||||
}
|
||||
ak := a.bandwidth()
|
||||
bk := b.bandwidth()
|
||||
if ak != bk {
|
||||
return false
|
||||
}
|
||||
for i := 0; i < an; i++ {
|
||||
for j := 0; j < an; j++ {
|
||||
if a.at(i, j) != b.at(i, j) || math.IsNaN(a.at(i, j)) != math.IsNaN(b.at(i, j)) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
var symmetricBandTests = []SymmetricBand{
|
||||
{N: 3, K: 0, Stride: 1, Uplo: blas.Upper, Data: []float32{
|
||||
1,
|
||||
2,
|
||||
3,
|
||||
}},
|
||||
{N: 3, K: 0, Stride: 1, Uplo: blas.Lower, Data: []float32{
|
||||
1,
|
||||
2,
|
||||
3,
|
||||
}},
|
||||
{N: 3, K: 1, Stride: 2, Uplo: blas.Upper, Data: []float32{
|
||||
1, 2,
|
||||
3, 4,
|
||||
5, -1,
|
||||
}},
|
||||
{N: 3, K: 1, Stride: 2, Uplo: blas.Lower, Data: []float32{
|
||||
-1, 1,
|
||||
2, 3,
|
||||
4, 5,
|
||||
}},
|
||||
{N: 3, K: 2, Stride: 3, Uplo: blas.Upper, Data: []float32{
|
||||
1, 2, 3,
|
||||
4, 5, -1,
|
||||
6, -2, -3,
|
||||
}},
|
||||
{N: 3, K: 2, Stride: 3, Uplo: blas.Lower, Data: []float32{
|
||||
-2, -1, 1,
|
||||
-3, 2, 4,
|
||||
3, 5, 6,
|
||||
}},
|
||||
|
||||
{N: 3, K: 0, Stride: 5, Uplo: blas.Upper, Data: []float32{
|
||||
1, 0, 0, 0, 0,
|
||||
2, 0, 0, 0, 0,
|
||||
3, 0, 0, 0, 0,
|
||||
}},
|
||||
{N: 3, K: 0, Stride: 5, Uplo: blas.Lower, Data: []float32{
|
||||
1, 0, 0, 0, 0,
|
||||
2, 0, 0, 0, 0,
|
||||
3, 0, 0, 0, 0,
|
||||
}},
|
||||
{N: 3, K: 1, Stride: 5, Uplo: blas.Upper, Data: []float32{
|
||||
1, 2, 0, 0, 0,
|
||||
3, 4, 0, 0, 0,
|
||||
5, -1, 0, 0, 0,
|
||||
}},
|
||||
{N: 3, K: 1, Stride: 5, Uplo: blas.Lower, Data: []float32{
|
||||
-1, 1, 0, 0, 0,
|
||||
2, 3, 0, 0, 0,
|
||||
4, 5, 0, 0, 0,
|
||||
}},
|
||||
{N: 3, K: 2, Stride: 5, Uplo: blas.Upper, Data: []float32{
|
||||
1, 2, 3, 0, 0,
|
||||
4, 5, -1, 0, 0,
|
||||
6, -2, -3, 0, 0,
|
||||
}},
|
||||
{N: 3, K: 2, Stride: 5, Uplo: blas.Lower, Data: []float32{
|
||||
-2, -1, 1, 0, 0,
|
||||
-3, 2, 4, 0, 0,
|
||||
3, 5, 6, 0, 0,
|
||||
}},
|
||||
}
|
||||
|
||||
func TestConvertSymBand(t *testing.T) {
|
||||
for _, test := range symmetricBandTests {
|
||||
colmajor := newSymmetricBandColsFrom(test)
|
||||
if !equalSymmetricBand(colmajor, test) {
|
||||
t.Errorf("unexpected result for row major to col major conversion:\n\tgot: %#v\n\tfrom:%#v",
|
||||
colmajor, test)
|
||||
}
|
||||
rowmajor := newSymmetricBandFrom(colmajor)
|
||||
if !equalSymmetricBand(rowmajor, test) {
|
||||
t.Errorf("unexpected result for col major to row major conversion:\n\tgot: %#v\n\twant:%#v",
|
||||
rowmajor, test)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func newTriangularBandFrom(a TriangularBandCols) TriangularBand {
|
||||
t := TriangularBand{
|
||||
N: a.N,
|
||||
|
Reference in New Issue
Block a user