mirror of
https://github.com/gonum/gonum.git
synced 2025-10-06 07:37:03 +08:00
mat: add non-conjugate transpose for complex matrices
This commit is contained in:
@@ -4,7 +4,11 @@
|
||||
|
||||
package mat
|
||||
|
||||
import "gonum.org/v1/gonum/blas/cblas128"
|
||||
import (
|
||||
"math/cmplx"
|
||||
|
||||
"gonum.org/v1/gonum/blas/cblas128"
|
||||
)
|
||||
|
||||
var (
|
||||
cDense *CDense
|
||||
@@ -29,12 +33,56 @@ func (m *CDense) Dims() (r, c int) {
|
||||
func (m *CDense) Caps() (r, c int) { return m.capRows, m.capCols }
|
||||
|
||||
// H performs an implicit conjugate transpose by returning the receiver inside a
|
||||
// Conjugate.
|
||||
// ConjTranspose.
|
||||
func (m *CDense) H() CMatrix {
|
||||
return Conjugate{m}
|
||||
return ConjTranspose{m}
|
||||
}
|
||||
|
||||
// Slice returns a new Matrix that shares backing data with the receiver.
|
||||
// T performs an implicit transpose by returning the receiver inside a
|
||||
// CTranspose.
|
||||
func (m *CDense) T() CMatrix {
|
||||
return CTranspose{m}
|
||||
}
|
||||
|
||||
// Conj calculates the element-wise conjugate of a and stores the result in the
|
||||
// receiver.
|
||||
// Conj will panic if m and a do not have the same dimension unless m is empty.
|
||||
func (m *CDense) Conj(a CMatrix) {
|
||||
ar, ac := a.Dims()
|
||||
aU, aTrans, aConj := untransposeExtractCmplx(a)
|
||||
m.reuseAsNonZeroed(ar, ac)
|
||||
|
||||
if arm, ok := a.(*CDense); ok {
|
||||
amat := arm.mat
|
||||
if m != aU {
|
||||
m.checkOverlap(amat)
|
||||
}
|
||||
for ja, jm := 0, 0; ja < ar*amat.Stride; ja, jm = ja+amat.Stride, jm+m.mat.Stride {
|
||||
for i, v := range amat.Data[ja : ja+ac] {
|
||||
m.mat.Data[i+jm] = cmplx.Conj(v)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
m.checkOverlapMatrix(aU)
|
||||
if aTrans != aConj && m == aU {
|
||||
// Only make workspace if the destination is transposed
|
||||
// with respect to the source and they are the same
|
||||
// matrix.
|
||||
var restore func()
|
||||
m, restore = m.isolatedWorkspace(aU)
|
||||
defer restore()
|
||||
}
|
||||
|
||||
for r := 0; r < ar; r++ {
|
||||
for c := 0; c < ac; c++ {
|
||||
m.set(r, c, cmplx.Conj(a.At(r, c)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Slice returns a new CMatrix that shares backing data with the receiver.
|
||||
// The returned matrix starts at {i,j} of the receiver and extends k-i rows
|
||||
// and l-j columns. The final row in the resulting matrix is k-1 and the
|
||||
// final column is l-1.
|
||||
@@ -61,7 +109,6 @@ func (m *CDense) slice(i, k, j, l int) *CDense {
|
||||
return &t
|
||||
}
|
||||
|
||||
|
||||
// NewCDense creates a new complex Dense matrix with r rows and c columns.
|
||||
// If data == nil, a new slice is allocated for the backing slice.
|
||||
// If len(data) == r*c, data is used as the backing slice, and changes to the
|
||||
@@ -171,6 +218,21 @@ func (m *CDense) reuseAsZeroed(r, c int) {
|
||||
m.Zero()
|
||||
}
|
||||
|
||||
// isolatedWorkspace returns a new dense matrix w with the size of a and
|
||||
// returns a callback to defer which performs cleanup at the return of the call.
|
||||
// This should be used when a method receiver is the same pointer as an input argument.
|
||||
func (m *CDense) isolatedWorkspace(a CMatrix) (w *CDense, restore func()) {
|
||||
r, c := a.Dims()
|
||||
if r == 0 || c == 0 {
|
||||
panic(ErrZeroLength)
|
||||
}
|
||||
w = getWorkspaceCmplx(r, c, false)
|
||||
return w, func() {
|
||||
m.Copy(w)
|
||||
putWorkspaceCmplx(w)
|
||||
}
|
||||
}
|
||||
|
||||
// Reset zeros the dimensions of the matrix so that it can be reused as the
|
||||
// receiver of a dimensionally restricted operation.
|
||||
//
|
||||
@@ -227,6 +289,14 @@ func (m *CDense) Copy(a CMatrix) (r, c int) {
|
||||
return r, c
|
||||
}
|
||||
|
||||
// SetRawCMatrix sets the underlying cblas128.General used by the receiver.
|
||||
// Changes to elements in the receiver following the call will be reflected
|
||||
// in b.
|
||||
func (m *CDense) SetRawCMatrix(b cblas128.General) {
|
||||
m.capRows, m.capCols = b.Rows, b.Cols
|
||||
m.mat = b
|
||||
}
|
||||
|
||||
// RawCMatrix returns the underlying cblas128.General used by the receiver.
|
||||
// Changes to elements in the receiver following the call will be reflected
|
||||
// in returned cblas128.General.
|
||||
|
@@ -4,7 +4,12 @@
|
||||
|
||||
package mat
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"math/cmplx"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/exp/rand"
|
||||
)
|
||||
|
||||
func TestCDenseNewAtSet(t *testing.T) {
|
||||
t.Parallel()
|
||||
@@ -48,6 +53,96 @@ func TestCDenseNewAtSet(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCDenseConjElem(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rnd := rand.New(rand.NewSource(1))
|
||||
|
||||
for r := 1; r <= 8; r++ {
|
||||
for c := 1; c <= 8; c++ {
|
||||
const (
|
||||
empty = iota
|
||||
fit
|
||||
sliced
|
||||
self
|
||||
)
|
||||
for _, dst := range []int{empty, fit, sliced, self} {
|
||||
const (
|
||||
noTrans = iota
|
||||
trans
|
||||
conjTrans
|
||||
bothHT
|
||||
bothTH
|
||||
)
|
||||
for _, src := range []int{noTrans, trans, conjTrans, bothHT, bothTH} {
|
||||
d := NewCDense(r, c, nil)
|
||||
for i := 0; i < r; i++ {
|
||||
for j := 0; j < c; j++ {
|
||||
d.Set(i, j, complex(rnd.NormFloat64(), rnd.NormFloat64()))
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
a CMatrix
|
||||
op string
|
||||
)
|
||||
switch src {
|
||||
case noTrans:
|
||||
a = d
|
||||
case trans:
|
||||
r, c = c, r
|
||||
a = d.T()
|
||||
op = ".T"
|
||||
case conjTrans:
|
||||
r, c = c, r
|
||||
a = d.H()
|
||||
op = ".H"
|
||||
case bothHT:
|
||||
a = d.H().T()
|
||||
op = ".H.T"
|
||||
case bothTH:
|
||||
a = d.T().H()
|
||||
op = ".T.H"
|
||||
default:
|
||||
panic("invalid src op")
|
||||
}
|
||||
aCopy := NewCDense(r, c, nil)
|
||||
aCopy.Copy(a)
|
||||
|
||||
var got *CDense
|
||||
switch dst {
|
||||
case empty:
|
||||
got = &CDense{}
|
||||
case fit:
|
||||
got = NewCDense(r, c, nil)
|
||||
case sliced:
|
||||
got = NewCDense(r*2, c*2, nil).Slice(1, r+1, 1, c+1).(*CDense)
|
||||
case self:
|
||||
if r != c && (src == conjTrans || src == trans) {
|
||||
continue
|
||||
}
|
||||
got = d
|
||||
default:
|
||||
panic("invalid dst size")
|
||||
}
|
||||
|
||||
got.Conj(a)
|
||||
|
||||
for i := 0; i < r; i++ {
|
||||
for j := 0; j < c; j++ {
|
||||
if got.At(i, j) != cmplx.Conj(aCopy.At(i, j)) {
|
||||
t.Errorf("unexpected results a%s[%d, %d] for r=%d c=%d %v != %v",
|
||||
op, i, j, r, c, got.At(i, j), cmplx.Conj(a.At(i, j)),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCDenseGrow(t *testing.T) {
|
||||
t.Parallel()
|
||||
m := &CDense{}
|
||||
|
157
mat/cmatrix.go
157
mat/cmatrix.go
@@ -14,18 +14,24 @@ import (
|
||||
|
||||
// CMatrix is the basic matrix interface type for complex matrices.
|
||||
type CMatrix interface {
|
||||
// Dims returns the dimensions of a Matrix.
|
||||
// Dims returns the dimensions of a CMatrix.
|
||||
Dims() (r, c int)
|
||||
|
||||
// At returns the value of a matrix element at row i, column j.
|
||||
// It will panic if i or j are out of bounds for the matrix.
|
||||
At(i, j int) complex128
|
||||
|
||||
// H returns the conjugate transpose of the Matrix. Whether H
|
||||
// H returns the conjugate transpose of the CMatrix. Whether H
|
||||
// returns a copy of the underlying data is implementation dependent.
|
||||
// This method may be implemented using the Conjugate type, which
|
||||
// This method may be implemented using the ConjTranspose type, which
|
||||
// provides an implicit matrix conjugate transpose.
|
||||
H() CMatrix
|
||||
|
||||
// T returns the transpose of the CMatrix. Whether T returns a copy of the
|
||||
// underlying data is implementation dependent.
|
||||
// This method may be implemented using the CTranspose type, which
|
||||
// provides an implicit matrix transpose.
|
||||
T() CMatrix
|
||||
}
|
||||
|
||||
// A RawCMatrixer can return a cblas128.General representation of the receiver. Changes to the cblas128.General.Data
|
||||
@@ -35,53 +41,109 @@ type RawCMatrixer interface {
|
||||
}
|
||||
|
||||
var (
|
||||
_ CMatrix = Conjugate{}
|
||||
_ Unconjugator = Conjugate{}
|
||||
_ CMatrix = ConjTranspose{}
|
||||
_ UnConjTransposer = ConjTranspose{}
|
||||
)
|
||||
|
||||
// Conjugate is a type for performing an implicit matrix conjugate transpose.
|
||||
// It implements the Matrix interface, returning values from the conjugate
|
||||
// ConjTranspose is a type for performing an implicit matrix conjugate transpose.
|
||||
// It implements the CMatrix interface, returning values from the conjugate
|
||||
// transpose of the matrix within.
|
||||
type Conjugate struct {
|
||||
type ConjTranspose struct {
|
||||
CMatrix CMatrix
|
||||
}
|
||||
|
||||
// At returns the value of the element at row i and column j of the conjugate
|
||||
// transposed matrix, that is, row j and column i of the Matrix field.
|
||||
func (t Conjugate) At(i, j int) complex128 {
|
||||
// transposed matrix, that is, row j and column i of the CMatrix field.
|
||||
func (t ConjTranspose) At(i, j int) complex128 {
|
||||
z := t.CMatrix.At(j, i)
|
||||
return cmplx.Conj(z)
|
||||
}
|
||||
|
||||
// Dims returns the dimensions of the transposed matrix. The number of rows returned
|
||||
// is the number of columns in the Matrix field, and the number of columns is
|
||||
// the number of rows in the Matrix field.
|
||||
func (t Conjugate) Dims() (r, c int) {
|
||||
// is the number of columns in the CMatrix field, and the number of columns is
|
||||
// the number of rows in the CMatrix field.
|
||||
func (t ConjTranspose) Dims() (r, c int) {
|
||||
c, r = t.CMatrix.Dims()
|
||||
return r, c
|
||||
}
|
||||
|
||||
// H performs an implicit conjugate transpose by returning the Matrix field.
|
||||
func (t Conjugate) H() CMatrix {
|
||||
// H performs an implicit conjugate transpose by returning the CMatrix field.
|
||||
func (t ConjTranspose) H() CMatrix {
|
||||
return t.CMatrix
|
||||
}
|
||||
|
||||
// Unconjugate returns the Matrix field.
|
||||
func (t Conjugate) Unconjugate() CMatrix {
|
||||
// T performs an implicit transpose by returning the receiver inside a
|
||||
// CTranspose.
|
||||
func (t ConjTranspose) T() CMatrix {
|
||||
return CTranspose{t}
|
||||
}
|
||||
|
||||
// UnConjTranspose returns the CMatrix field.
|
||||
func (t ConjTranspose) UnConjTranspose() CMatrix {
|
||||
return t.CMatrix
|
||||
}
|
||||
|
||||
// Unconjugator is a type that can undo an implicit conjugate transpose.
|
||||
type Unconjugator interface {
|
||||
// CTranspose is a type for performing an implicit matrix conjugate transpose.
|
||||
// It implements the CMatrix interface, returning values from the conjugate
|
||||
// transpose of the matrix within.
|
||||
type CTranspose struct {
|
||||
CMatrix CMatrix
|
||||
}
|
||||
|
||||
// At returns the value of the element at row i and column j of the conjugate
|
||||
// transposed matrix, that is, row j and column i of the CMatrix field.
|
||||
func (t CTranspose) At(i, j int) complex128 {
|
||||
return t.CMatrix.At(j, i)
|
||||
}
|
||||
|
||||
// Dims returns the dimensions of the transposed matrix. The number of rows returned
|
||||
// is the number of columns in the CMatrix field, and the number of columns is
|
||||
// the number of rows in the CMatrix field.
|
||||
func (t CTranspose) Dims() (r, c int) {
|
||||
c, r = t.CMatrix.Dims()
|
||||
return r, c
|
||||
}
|
||||
|
||||
// H performs an implicit transpose by returning the receiver inside a
|
||||
// ConjTranspose.
|
||||
func (t CTranspose) H() CMatrix {
|
||||
return ConjTranspose{t}
|
||||
}
|
||||
|
||||
// T performs an implicit conjugate transpose by returning the CMatrix field.
|
||||
func (t CTranspose) T() CMatrix {
|
||||
return t.CMatrix
|
||||
}
|
||||
|
||||
// Untranspose returns the CMatrix field.
|
||||
func (t CTranspose) Untranspose() CMatrix {
|
||||
return t.CMatrix
|
||||
}
|
||||
|
||||
// UnConjTransposer is a type that can undo an implicit conjugate transpose.
|
||||
type UnConjTransposer interface {
|
||||
// UnConjTranspose returns the underlying CMatrix stored for the implicit
|
||||
// conjugate transpose.
|
||||
UnConjTranspose() CMatrix
|
||||
|
||||
// Note: This interface is needed to unify all of the Conjugate types. In
|
||||
// the cmat128 methods, we need to test if the Matrix has been implicitly
|
||||
// the cmat128 methods, we need to test if the CMatrix has been implicitly
|
||||
// transposed. If this is checked by testing for the specific Conjugate type
|
||||
// then the behavior will be different if the user uses H() or HTri() for a
|
||||
// triangular matrix.
|
||||
}
|
||||
|
||||
// Unconjugate returns the underlying Matrix stored for the implicit
|
||||
// conjugate transpose.
|
||||
Unconjugate() CMatrix
|
||||
// CUntransposer is a type that can undo an implicit transpose.
|
||||
type CUntransposer interface {
|
||||
// Untranspose returns the underlying CMatrix stored for the implicit
|
||||
// transpose.
|
||||
Untranspose() CMatrix
|
||||
|
||||
// Note: This interface is needed to unify all of the CTranspose types. In
|
||||
// the cmat128 methods, we need to test if the CMatrix has been implicitly
|
||||
// transposed. If this is checked by testing for the specific CTranspose type
|
||||
// then the behavior will be different if the user uses T() or TTri() for a
|
||||
// triangular matrix.
|
||||
}
|
||||
|
||||
// useC returns a complex128 slice with l elements, using c if it
|
||||
@@ -112,14 +174,49 @@ func zeroC(c []complex128) {
|
||||
}
|
||||
}
|
||||
|
||||
// unconjugate unconjugates a matrix if applicable. If a is an Unconjugator, then
|
||||
// unconjugate returns the underlying matrix and true. If it is not, then it returns
|
||||
// the input matrix and false.
|
||||
func unconjugate(a CMatrix) (CMatrix, bool) {
|
||||
if ut, ok := a.(Unconjugator); ok {
|
||||
return ut.Unconjugate(), true
|
||||
// untransposeCmplx untransposes a matrix if applicable. If a is an CUntransposer
|
||||
// or an UnConjTransposer, then untranspose returns the underlying matrix and true for
|
||||
// the kind of transpose (potentially both).
|
||||
// If it is not, then it returns the input matrix and false for trans and conj.
|
||||
func untransposeCmplx(a CMatrix) (u CMatrix, trans, conj bool) {
|
||||
switch ut := a.(type) {
|
||||
case CUntransposer:
|
||||
trans = true
|
||||
u := ut.Untranspose()
|
||||
if uc, ok := u.(UnConjTransposer); ok {
|
||||
return uc.UnConjTranspose(), trans, true
|
||||
}
|
||||
return u, trans, false
|
||||
case UnConjTransposer:
|
||||
conj = true
|
||||
u := ut.UnConjTranspose()
|
||||
if ut, ok := u.(CUntransposer); ok {
|
||||
return ut.Untranspose(), true, conj
|
||||
}
|
||||
return u, false, conj
|
||||
default:
|
||||
return a, false, false
|
||||
}
|
||||
}
|
||||
|
||||
// untransposeExtractCmplx returns an untransposed matrix in a built-in matrix type.
|
||||
//
|
||||
// The untransposed matrix is returned unaltered if it is a built-in matrix type.
|
||||
// Otherwise, if it implements a Raw method, an appropriate built-in type value
|
||||
// is returned holding the raw matrix value of the input. If neither of these
|
||||
// is possible, the untransposed matrix is returned.
|
||||
func untransposeExtractCmplx(a CMatrix) (u CMatrix, trans, conj bool) {
|
||||
ut, trans, conj := untransposeCmplx(a)
|
||||
switch m := ut.(type) {
|
||||
case *CDense:
|
||||
return m, trans, conj
|
||||
case RawCMatrixer:
|
||||
var d CDense
|
||||
d.SetRawCMatrix(m.RawCMatrix())
|
||||
return &d, trans, conj
|
||||
default:
|
||||
return ut, trans, conj
|
||||
}
|
||||
return a, false
|
||||
}
|
||||
|
||||
// CEqual returns whether the matrices a and b have the same size
|
||||
|
35
mat/pool.go
35
mat/pool.go
@@ -9,6 +9,7 @@ import (
|
||||
|
||||
"gonum.org/v1/gonum/blas"
|
||||
"gonum.org/v1/gonum/blas/blas64"
|
||||
"gonum.org/v1/gonum/blas/cblas128"
|
||||
)
|
||||
|
||||
var tab64 = [64]byte{
|
||||
@@ -59,6 +60,9 @@ var (
|
||||
|
||||
// poolInts is the []int equivalent of pool.
|
||||
poolInts [63]sync.Pool
|
||||
|
||||
// poolCmplx is the CDense equivalent of pool.
|
||||
poolCmplx [63]sync.Pool
|
||||
)
|
||||
|
||||
func init() {
|
||||
@@ -94,6 +98,12 @@ func init() {
|
||||
s := make([]int, l)
|
||||
return &s
|
||||
}
|
||||
|
||||
poolCmplx[i].New = func() interface{} {
|
||||
return &CDense{mat: cblas128.General{
|
||||
Data: make([]complex128, l),
|
||||
}}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -236,3 +246,28 @@ func getInts(l int, clear bool) []int {
|
||||
func putInts(w []int) {
|
||||
poolInts[bits(uint64(cap(w)))].Put(&w)
|
||||
}
|
||||
|
||||
// getWorkspaceCmplx returns a *CDense of size r×c and a data slice
|
||||
// with a cap that is less than 2*r*c. If clear is true, the
|
||||
// data slice visible through the CMatrix interface is zeroed.
|
||||
func getWorkspaceCmplx(r, c int, clear bool) *CDense {
|
||||
l := uint64(r * c)
|
||||
w := poolCmplx[bits(l)].Get().(*CDense)
|
||||
w.mat.Data = w.mat.Data[:l]
|
||||
if clear {
|
||||
zeroC(w.mat.Data)
|
||||
}
|
||||
w.mat.Rows = r
|
||||
w.mat.Cols = c
|
||||
w.mat.Stride = c
|
||||
w.capRows = r
|
||||
w.capCols = c
|
||||
return w
|
||||
}
|
||||
|
||||
// putWorkspaceCmplx replaces a used *CDense into the appropriate size
|
||||
// workspace pool. putWorkspace must not be called with a matrix
|
||||
// where references to the underlying data slice have been kept.
|
||||
func putWorkspaceCmplx(w *CDense) {
|
||||
poolCmplx[bits(uint64(cap(w.mat.Data)))].Put(w)
|
||||
}
|
||||
|
@@ -53,7 +53,7 @@ func checkOverlapComplex(a, b cblas128.General) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *CDense) checkOverlapComplex(a cblas128.General) bool {
|
||||
func (m *CDense) checkOverlap(a cblas128.General) bool {
|
||||
return checkOverlapComplex(m.RawCMatrix(), a)
|
||||
}
|
||||
|
||||
@@ -68,5 +68,5 @@ func (m *CDense) checkOverlapMatrix(a CMatrix) bool {
|
||||
case RawCMatrixer:
|
||||
amat = ar.RawCMatrix()
|
||||
}
|
||||
return m.checkOverlapComplex(amat)
|
||||
return m.checkOverlap(amat)
|
||||
}
|
||||
|
Reference in New Issue
Block a user