mat: add non-conjugate transpose for complex matrices

This commit is contained in:
Dan Kortschak
2021-01-07 07:48:49 +10:30
committed by GitHub
parent 9076f1c7d1
commit 1c2011e56d
5 changed files with 335 additions and 38 deletions

View File

@@ -4,7 +4,11 @@
package mat
import "gonum.org/v1/gonum/blas/cblas128"
import (
"math/cmplx"
"gonum.org/v1/gonum/blas/cblas128"
)
var (
cDense *CDense
@@ -29,12 +33,56 @@ func (m *CDense) Dims() (r, c int) {
func (m *CDense) Caps() (r, c int) { return m.capRows, m.capCols }
// H performs an implicit conjugate transpose by returning the receiver inside a
// Conjugate.
// ConjTranspose.
func (m *CDense) H() CMatrix {
return Conjugate{m}
return ConjTranspose{m}
}
// Slice returns a new Matrix that shares backing data with the receiver.
// T performs an implicit transpose by returning the receiver inside a
// CTranspose.
func (m *CDense) T() CMatrix {
return CTranspose{m}
}
// Conj calculates the element-wise conjugate of a and stores the result in the
// receiver.
// Conj will panic if m and a do not have the same dimension unless m is empty.
func (m *CDense) Conj(a CMatrix) {
ar, ac := a.Dims()
aU, aTrans, aConj := untransposeExtractCmplx(a)
m.reuseAsNonZeroed(ar, ac)
if arm, ok := a.(*CDense); ok {
amat := arm.mat
if m != aU {
m.checkOverlap(amat)
}
for ja, jm := 0, 0; ja < ar*amat.Stride; ja, jm = ja+amat.Stride, jm+m.mat.Stride {
for i, v := range amat.Data[ja : ja+ac] {
m.mat.Data[i+jm] = cmplx.Conj(v)
}
}
return
}
m.checkOverlapMatrix(aU)
if aTrans != aConj && m == aU {
// Only make workspace if the destination is transposed
// with respect to the source and they are the same
// matrix.
var restore func()
m, restore = m.isolatedWorkspace(aU)
defer restore()
}
for r := 0; r < ar; r++ {
for c := 0; c < ac; c++ {
m.set(r, c, cmplx.Conj(a.At(r, c)))
}
}
}
// Slice returns a new CMatrix that shares backing data with the receiver.
// The returned matrix starts at {i,j} of the receiver and extends k-i rows
// and l-j columns. The final row in the resulting matrix is k-1 and the
// final column is l-1.
@@ -61,7 +109,6 @@ func (m *CDense) slice(i, k, j, l int) *CDense {
return &t
}
// NewCDense creates a new complex Dense matrix with r rows and c columns.
// If data == nil, a new slice is allocated for the backing slice.
// If len(data) == r*c, data is used as the backing slice, and changes to the
@@ -171,6 +218,21 @@ func (m *CDense) reuseAsZeroed(r, c int) {
m.Zero()
}
// isolatedWorkspace returns a new dense matrix w with the size of a and
// returns a callback to defer which performs cleanup at the return of the call.
// This should be used when a method receiver is the same pointer as an input argument.
func (m *CDense) isolatedWorkspace(a CMatrix) (w *CDense, restore func()) {
r, c := a.Dims()
if r == 0 || c == 0 {
panic(ErrZeroLength)
}
w = getWorkspaceCmplx(r, c, false)
return w, func() {
m.Copy(w)
putWorkspaceCmplx(w)
}
}
// Reset zeros the dimensions of the matrix so that it can be reused as the
// receiver of a dimensionally restricted operation.
//
@@ -227,6 +289,14 @@ func (m *CDense) Copy(a CMatrix) (r, c int) {
return r, c
}
// SetRawCMatrix sets the underlying cblas128.General used by the receiver.
// Changes to elements in the receiver following the call will be reflected
// in b.
func (m *CDense) SetRawCMatrix(b cblas128.General) {
m.capRows, m.capCols = b.Rows, b.Cols
m.mat = b
}
// RawCMatrix returns the underlying cblas128.General used by the receiver.
// Changes to elements in the receiver following the call will be reflected
// in returned cblas128.General.

View File

@@ -4,7 +4,12 @@
package mat
import "testing"
import (
"math/cmplx"
"testing"
"golang.org/x/exp/rand"
)
func TestCDenseNewAtSet(t *testing.T) {
t.Parallel()
@@ -48,6 +53,96 @@ func TestCDenseNewAtSet(t *testing.T) {
}
}
func TestCDenseConjElem(t *testing.T) {
t.Parallel()
rnd := rand.New(rand.NewSource(1))
for r := 1; r <= 8; r++ {
for c := 1; c <= 8; c++ {
const (
empty = iota
fit
sliced
self
)
for _, dst := range []int{empty, fit, sliced, self} {
const (
noTrans = iota
trans
conjTrans
bothHT
bothTH
)
for _, src := range []int{noTrans, trans, conjTrans, bothHT, bothTH} {
d := NewCDense(r, c, nil)
for i := 0; i < r; i++ {
for j := 0; j < c; j++ {
d.Set(i, j, complex(rnd.NormFloat64(), rnd.NormFloat64()))
}
}
var (
a CMatrix
op string
)
switch src {
case noTrans:
a = d
case trans:
r, c = c, r
a = d.T()
op = ".T"
case conjTrans:
r, c = c, r
a = d.H()
op = ".H"
case bothHT:
a = d.H().T()
op = ".H.T"
case bothTH:
a = d.T().H()
op = ".T.H"
default:
panic("invalid src op")
}
aCopy := NewCDense(r, c, nil)
aCopy.Copy(a)
var got *CDense
switch dst {
case empty:
got = &CDense{}
case fit:
got = NewCDense(r, c, nil)
case sliced:
got = NewCDense(r*2, c*2, nil).Slice(1, r+1, 1, c+1).(*CDense)
case self:
if r != c && (src == conjTrans || src == trans) {
continue
}
got = d
default:
panic("invalid dst size")
}
got.Conj(a)
for i := 0; i < r; i++ {
for j := 0; j < c; j++ {
if got.At(i, j) != cmplx.Conj(aCopy.At(i, j)) {
t.Errorf("unexpected results a%s[%d, %d] for r=%d c=%d %v != %v",
op, i, j, r, c, got.At(i, j), cmplx.Conj(a.At(i, j)),
)
}
}
}
}
}
}
}
}
func TestCDenseGrow(t *testing.T) {
t.Parallel()
m := &CDense{}

View File

@@ -14,18 +14,24 @@ import (
// CMatrix is the basic matrix interface type for complex matrices.
type CMatrix interface {
// Dims returns the dimensions of a Matrix.
// Dims returns the dimensions of a CMatrix.
Dims() (r, c int)
// At returns the value of a matrix element at row i, column j.
// It will panic if i or j are out of bounds for the matrix.
At(i, j int) complex128
// H returns the conjugate transpose of the Matrix. Whether H
// H returns the conjugate transpose of the CMatrix. Whether H
// returns a copy of the underlying data is implementation dependent.
// This method may be implemented using the Conjugate type, which
// This method may be implemented using the ConjTranspose type, which
// provides an implicit matrix conjugate transpose.
H() CMatrix
// T returns the transpose of the CMatrix. Whether T returns a copy of the
// underlying data is implementation dependent.
// This method may be implemented using the CTranspose type, which
// provides an implicit matrix transpose.
T() CMatrix
}
// A RawCMatrixer can return a cblas128.General representation of the receiver. Changes to the cblas128.General.Data
@@ -35,53 +41,109 @@ type RawCMatrixer interface {
}
var (
_ CMatrix = Conjugate{}
_ Unconjugator = Conjugate{}
_ CMatrix = ConjTranspose{}
_ UnConjTransposer = ConjTranspose{}
)
// Conjugate is a type for performing an implicit matrix conjugate transpose.
// It implements the Matrix interface, returning values from the conjugate
// ConjTranspose is a type for performing an implicit matrix conjugate transpose.
// It implements the CMatrix interface, returning values from the conjugate
// transpose of the matrix within.
type Conjugate struct {
type ConjTranspose struct {
CMatrix CMatrix
}
// At returns the value of the element at row i and column j of the conjugate
// transposed matrix, that is, row j and column i of the Matrix field.
func (t Conjugate) At(i, j int) complex128 {
// transposed matrix, that is, row j and column i of the CMatrix field.
func (t ConjTranspose) At(i, j int) complex128 {
z := t.CMatrix.At(j, i)
return cmplx.Conj(z)
}
// Dims returns the dimensions of the transposed matrix. The number of rows returned
// is the number of columns in the Matrix field, and the number of columns is
// the number of rows in the Matrix field.
func (t Conjugate) Dims() (r, c int) {
// is the number of columns in the CMatrix field, and the number of columns is
// the number of rows in the CMatrix field.
func (t ConjTranspose) Dims() (r, c int) {
c, r = t.CMatrix.Dims()
return r, c
}
// H performs an implicit conjugate transpose by returning the Matrix field.
func (t Conjugate) H() CMatrix {
// H performs an implicit conjugate transpose by returning the CMatrix field.
func (t ConjTranspose) H() CMatrix {
return t.CMatrix
}
// Unconjugate returns the Matrix field.
func (t Conjugate) Unconjugate() CMatrix {
// T performs an implicit transpose by returning the receiver inside a
// CTranspose.
func (t ConjTranspose) T() CMatrix {
return CTranspose{t}
}
// UnConjTranspose returns the CMatrix field.
func (t ConjTranspose) UnConjTranspose() CMatrix {
return t.CMatrix
}
// Unconjugator is a type that can undo an implicit conjugate transpose.
type Unconjugator interface {
// CTranspose is a type for performing an implicit matrix conjugate transpose.
// It implements the CMatrix interface, returning values from the conjugate
// transpose of the matrix within.
type CTranspose struct {
CMatrix CMatrix
}
// At returns the value of the element at row i and column j of the conjugate
// transposed matrix, that is, row j and column i of the CMatrix field.
func (t CTranspose) At(i, j int) complex128 {
return t.CMatrix.At(j, i)
}
// Dims returns the dimensions of the transposed matrix. The number of rows returned
// is the number of columns in the CMatrix field, and the number of columns is
// the number of rows in the CMatrix field.
func (t CTranspose) Dims() (r, c int) {
c, r = t.CMatrix.Dims()
return r, c
}
// H performs an implicit transpose by returning the receiver inside a
// ConjTranspose.
func (t CTranspose) H() CMatrix {
return ConjTranspose{t}
}
// T performs an implicit conjugate transpose by returning the CMatrix field.
func (t CTranspose) T() CMatrix {
return t.CMatrix
}
// Untranspose returns the CMatrix field.
func (t CTranspose) Untranspose() CMatrix {
return t.CMatrix
}
// UnConjTransposer is a type that can undo an implicit conjugate transpose.
type UnConjTransposer interface {
// UnConjTranspose returns the underlying CMatrix stored for the implicit
// conjugate transpose.
UnConjTranspose() CMatrix
// Note: This interface is needed to unify all of the Conjugate types. In
// the cmat128 methods, we need to test if the Matrix has been implicitly
// the cmat128 methods, we need to test if the CMatrix has been implicitly
// transposed. If this is checked by testing for the specific Conjugate type
// then the behavior will be different if the user uses H() or HTri() for a
// triangular matrix.
}
// Unconjugate returns the underlying Matrix stored for the implicit
// conjugate transpose.
Unconjugate() CMatrix
// CUntransposer is a type that can undo an implicit transpose.
type CUntransposer interface {
// Untranspose returns the underlying CMatrix stored for the implicit
// transpose.
Untranspose() CMatrix
// Note: This interface is needed to unify all of the CTranspose types. In
// the cmat128 methods, we need to test if the CMatrix has been implicitly
// transposed. If this is checked by testing for the specific CTranspose type
// then the behavior will be different if the user uses T() or TTri() for a
// triangular matrix.
}
// useC returns a complex128 slice with l elements, using c if it
@@ -112,14 +174,49 @@ func zeroC(c []complex128) {
}
}
// unconjugate unconjugates a matrix if applicable. If a is an Unconjugator, then
// unconjugate returns the underlying matrix and true. If it is not, then it returns
// the input matrix and false.
func unconjugate(a CMatrix) (CMatrix, bool) {
if ut, ok := a.(Unconjugator); ok {
return ut.Unconjugate(), true
// untransposeCmplx untransposes a matrix if applicable. If a is an CUntransposer
// or an UnConjTransposer, then untranspose returns the underlying matrix and true for
// the kind of transpose (potentially both).
// If it is not, then it returns the input matrix and false for trans and conj.
func untransposeCmplx(a CMatrix) (u CMatrix, trans, conj bool) {
switch ut := a.(type) {
case CUntransposer:
trans = true
u := ut.Untranspose()
if uc, ok := u.(UnConjTransposer); ok {
return uc.UnConjTranspose(), trans, true
}
return u, trans, false
case UnConjTransposer:
conj = true
u := ut.UnConjTranspose()
if ut, ok := u.(CUntransposer); ok {
return ut.Untranspose(), true, conj
}
return u, false, conj
default:
return a, false, false
}
}
// untransposeExtractCmplx returns an untransposed matrix in a built-in matrix type.
//
// The untransposed matrix is returned unaltered if it is a built-in matrix type.
// Otherwise, if it implements a Raw method, an appropriate built-in type value
// is returned holding the raw matrix value of the input. If neither of these
// is possible, the untransposed matrix is returned.
func untransposeExtractCmplx(a CMatrix) (u CMatrix, trans, conj bool) {
ut, trans, conj := untransposeCmplx(a)
switch m := ut.(type) {
case *CDense:
return m, trans, conj
case RawCMatrixer:
var d CDense
d.SetRawCMatrix(m.RawCMatrix())
return &d, trans, conj
default:
return ut, trans, conj
}
return a, false
}
// CEqual returns whether the matrices a and b have the same size

View File

@@ -9,6 +9,7 @@ import (
"gonum.org/v1/gonum/blas"
"gonum.org/v1/gonum/blas/blas64"
"gonum.org/v1/gonum/blas/cblas128"
)
var tab64 = [64]byte{
@@ -59,6 +60,9 @@ var (
// poolInts is the []int equivalent of pool.
poolInts [63]sync.Pool
// poolCmplx is the CDense equivalent of pool.
poolCmplx [63]sync.Pool
)
func init() {
@@ -94,6 +98,12 @@ func init() {
s := make([]int, l)
return &s
}
poolCmplx[i].New = func() interface{} {
return &CDense{mat: cblas128.General{
Data: make([]complex128, l),
}}
}
}
}
@@ -236,3 +246,28 @@ func getInts(l int, clear bool) []int {
func putInts(w []int) {
poolInts[bits(uint64(cap(w)))].Put(&w)
}
// getWorkspaceCmplx returns a *CDense of size r×c and a data slice
// with a cap that is less than 2*r*c. If clear is true, the
// data slice visible through the CMatrix interface is zeroed.
func getWorkspaceCmplx(r, c int, clear bool) *CDense {
l := uint64(r * c)
w := poolCmplx[bits(l)].Get().(*CDense)
w.mat.Data = w.mat.Data[:l]
if clear {
zeroC(w.mat.Data)
}
w.mat.Rows = r
w.mat.Cols = c
w.mat.Stride = c
w.capRows = r
w.capCols = c
return w
}
// putWorkspaceCmplx replaces a used *CDense into the appropriate size
// workspace pool. putWorkspace must not be called with a matrix
// where references to the underlying data slice have been kept.
func putWorkspaceCmplx(w *CDense) {
poolCmplx[bits(uint64(cap(w.mat.Data)))].Put(w)
}

View File

@@ -53,7 +53,7 @@ func checkOverlapComplex(a, b cblas128.General) bool {
return false
}
func (m *CDense) checkOverlapComplex(a cblas128.General) bool {
func (m *CDense) checkOverlap(a cblas128.General) bool {
return checkOverlapComplex(m.RawCMatrix(), a)
}
@@ -68,5 +68,5 @@ func (m *CDense) checkOverlapMatrix(a CMatrix) bool {
case RawCMatrixer:
amat = ar.RawCMatrix()
}
return m.checkOverlapComplex(amat)
return m.checkOverlap(amat)
}