mat: add non-conjugate transpose for complex matrices

This commit is contained in:
Dan Kortschak
2021-01-07 07:48:49 +10:30
committed by GitHub
parent 9076f1c7d1
commit 1c2011e56d
5 changed files with 335 additions and 38 deletions

View File

@@ -14,18 +14,24 @@ import (
// CMatrix is the basic matrix interface type for complex matrices.
type CMatrix interface {
// Dims returns the dimensions of a Matrix.
// Dims returns the dimensions of a CMatrix.
Dims() (r, c int)
// At returns the value of a matrix element at row i, column j.
// It will panic if i or j are out of bounds for the matrix.
At(i, j int) complex128
// H returns the conjugate transpose of the Matrix. Whether H
// H returns the conjugate transpose of the CMatrix. Whether H
// returns a copy of the underlying data is implementation dependent.
// This method may be implemented using the Conjugate type, which
// This method may be implemented using the ConjTranspose type, which
// provides an implicit matrix conjugate transpose.
H() CMatrix
// T returns the transpose of the CMatrix. Whether T returns a copy of the
// underlying data is implementation dependent.
// This method may be implemented using the CTranspose type, which
// provides an implicit matrix transpose.
T() CMatrix
}
// A RawCMatrixer can return a cblas128.General representation of the receiver. Changes to the cblas128.General.Data
@@ -35,53 +41,109 @@ type RawCMatrixer interface {
}
var (
_ CMatrix = Conjugate{}
_ Unconjugator = Conjugate{}
_ CMatrix = ConjTranspose{}
_ UnConjTransposer = ConjTranspose{}
)
// Conjugate is a type for performing an implicit matrix conjugate transpose.
// It implements the Matrix interface, returning values from the conjugate
// ConjTranspose is a type for performing an implicit matrix conjugate transpose.
// It implements the CMatrix interface, returning values from the conjugate
// transpose of the matrix within.
type Conjugate struct {
type ConjTranspose struct {
CMatrix CMatrix
}
// At returns the value of the element at row i and column j of the conjugate
// transposed matrix, that is, row j and column i of the Matrix field.
func (t Conjugate) At(i, j int) complex128 {
// transposed matrix, that is, row j and column i of the CMatrix field.
func (t ConjTranspose) At(i, j int) complex128 {
z := t.CMatrix.At(j, i)
return cmplx.Conj(z)
}
// Dims returns the dimensions of the transposed matrix. The number of rows returned
// is the number of columns in the Matrix field, and the number of columns is
// the number of rows in the Matrix field.
func (t Conjugate) Dims() (r, c int) {
// is the number of columns in the CMatrix field, and the number of columns is
// the number of rows in the CMatrix field.
func (t ConjTranspose) Dims() (r, c int) {
c, r = t.CMatrix.Dims()
return r, c
}
// H performs an implicit conjugate transpose by returning the Matrix field.
func (t Conjugate) H() CMatrix {
// H performs an implicit conjugate transpose by returning the CMatrix field.
func (t ConjTranspose) H() CMatrix {
return t.CMatrix
}
// Unconjugate returns the Matrix field.
func (t Conjugate) Unconjugate() CMatrix {
// T performs an implicit transpose by returning the receiver inside a
// CTranspose.
func (t ConjTranspose) T() CMatrix {
return CTranspose{t}
}
// UnConjTranspose returns the CMatrix field.
func (t ConjTranspose) UnConjTranspose() CMatrix {
return t.CMatrix
}
// Unconjugator is a type that can undo an implicit conjugate transpose.
type Unconjugator interface {
// CTranspose is a type for performing an implicit matrix conjugate transpose.
// It implements the CMatrix interface, returning values from the conjugate
// transpose of the matrix within.
type CTranspose struct {
CMatrix CMatrix
}
// At returns the value of the element at row i and column j of the conjugate
// transposed matrix, that is, row j and column i of the CMatrix field.
func (t CTranspose) At(i, j int) complex128 {
return t.CMatrix.At(j, i)
}
// Dims returns the dimensions of the transposed matrix. The number of rows returned
// is the number of columns in the CMatrix field, and the number of columns is
// the number of rows in the CMatrix field.
func (t CTranspose) Dims() (r, c int) {
c, r = t.CMatrix.Dims()
return r, c
}
// H performs an implicit transpose by returning the receiver inside a
// ConjTranspose.
func (t CTranspose) H() CMatrix {
return ConjTranspose{t}
}
// T performs an implicit conjugate transpose by returning the CMatrix field.
func (t CTranspose) T() CMatrix {
return t.CMatrix
}
// Untranspose returns the CMatrix field.
func (t CTranspose) Untranspose() CMatrix {
return t.CMatrix
}
// UnConjTransposer is a type that can undo an implicit conjugate transpose.
type UnConjTransposer interface {
// UnConjTranspose returns the underlying CMatrix stored for the implicit
// conjugate transpose.
UnConjTranspose() CMatrix
// Note: This interface is needed to unify all of the Conjugate types. In
// the cmat128 methods, we need to test if the Matrix has been implicitly
// the cmat128 methods, we need to test if the CMatrix has been implicitly
// transposed. If this is checked by testing for the specific Conjugate type
// then the behavior will be different if the user uses H() or HTri() for a
// triangular matrix.
}
// Unconjugate returns the underlying Matrix stored for the implicit
// conjugate transpose.
Unconjugate() CMatrix
// CUntransposer is a type that can undo an implicit transpose.
type CUntransposer interface {
// Untranspose returns the underlying CMatrix stored for the implicit
// transpose.
Untranspose() CMatrix
// Note: This interface is needed to unify all of the CTranspose types. In
// the cmat128 methods, we need to test if the CMatrix has been implicitly
// transposed. If this is checked by testing for the specific CTranspose type
// then the behavior will be different if the user uses T() or TTri() for a
// triangular matrix.
}
// useC returns a complex128 slice with l elements, using c if it
@@ -112,14 +174,49 @@ func zeroC(c []complex128) {
}
}
// unconjugate unconjugates a matrix if applicable. If a is an Unconjugator, then
// unconjugate returns the underlying matrix and true. If it is not, then it returns
// the input matrix and false.
func unconjugate(a CMatrix) (CMatrix, bool) {
if ut, ok := a.(Unconjugator); ok {
return ut.Unconjugate(), true
// untransposeCmplx untransposes a matrix if applicable. If a is an CUntransposer
// or an UnConjTransposer, then untranspose returns the underlying matrix and true for
// the kind of transpose (potentially both).
// If it is not, then it returns the input matrix and false for trans and conj.
func untransposeCmplx(a CMatrix) (u CMatrix, trans, conj bool) {
switch ut := a.(type) {
case CUntransposer:
trans = true
u := ut.Untranspose()
if uc, ok := u.(UnConjTransposer); ok {
return uc.UnConjTranspose(), trans, true
}
return u, trans, false
case UnConjTransposer:
conj = true
u := ut.UnConjTranspose()
if ut, ok := u.(CUntransposer); ok {
return ut.Untranspose(), true, conj
}
return u, false, conj
default:
return a, false, false
}
}
// untransposeExtractCmplx returns an untransposed matrix in a built-in matrix type.
//
// The untransposed matrix is returned unaltered if it is a built-in matrix type.
// Otherwise, if it implements a Raw method, an appropriate built-in type value
// is returned holding the raw matrix value of the input. If neither of these
// is possible, the untransposed matrix is returned.
func untransposeExtractCmplx(a CMatrix) (u CMatrix, trans, conj bool) {
ut, trans, conj := untransposeCmplx(a)
switch m := ut.(type) {
case *CDense:
return m, trans, conj
case RawCMatrixer:
var d CDense
d.SetRawCMatrix(m.RawCMatrix())
return &d, trans, conj
default:
return ut, trans, conj
}
return a, false
}
// CEqual returns whether the matrices a and b have the same size