diff --git a/mat/cdense.go b/mat/cdense.go index 2c51fba1..02c4109d 100644 --- a/mat/cdense.go +++ b/mat/cdense.go @@ -4,7 +4,11 @@ package mat -import "gonum.org/v1/gonum/blas/cblas128" +import ( + "math/cmplx" + + "gonum.org/v1/gonum/blas/cblas128" +) var ( cDense *CDense @@ -29,12 +33,56 @@ func (m *CDense) Dims() (r, c int) { func (m *CDense) Caps() (r, c int) { return m.capRows, m.capCols } // H performs an implicit conjugate transpose by returning the receiver inside a -// Conjugate. +// ConjTranspose. func (m *CDense) H() CMatrix { - return Conjugate{m} + return ConjTranspose{m} } -// Slice returns a new Matrix that shares backing data with the receiver. +// T performs an implicit transpose by returning the receiver inside a +// CTranspose. +func (m *CDense) T() CMatrix { + return CTranspose{m} +} + +// Conj calculates the element-wise conjugate of a and stores the result in the +// receiver. +// Conj will panic if m and a do not have the same dimension unless m is empty. +func (m *CDense) Conj(a CMatrix) { + ar, ac := a.Dims() + aU, aTrans, aConj := untransposeExtractCmplx(a) + m.reuseAsNonZeroed(ar, ac) + + if arm, ok := a.(*CDense); ok { + amat := arm.mat + if m != aU { + m.checkOverlap(amat) + } + for ja, jm := 0, 0; ja < ar*amat.Stride; ja, jm = ja+amat.Stride, jm+m.mat.Stride { + for i, v := range amat.Data[ja : ja+ac] { + m.mat.Data[i+jm] = cmplx.Conj(v) + } + } + return + } + + m.checkOverlapMatrix(aU) + if aTrans != aConj && m == aU { + // Only make workspace if the destination is transposed + // with respect to the source and they are the same + // matrix. + var restore func() + m, restore = m.isolatedWorkspace(aU) + defer restore() + } + + for r := 0; r < ar; r++ { + for c := 0; c < ac; c++ { + m.set(r, c, cmplx.Conj(a.At(r, c))) + } + } +} + +// Slice returns a new CMatrix that shares backing data with the receiver. // The returned matrix starts at {i,j} of the receiver and extends k-i rows // and l-j columns. The final row in the resulting matrix is k-1 and the // final column is l-1. @@ -61,7 +109,6 @@ func (m *CDense) slice(i, k, j, l int) *CDense { return &t } - // NewCDense creates a new complex Dense matrix with r rows and c columns. // If data == nil, a new slice is allocated for the backing slice. // If len(data) == r*c, data is used as the backing slice, and changes to the @@ -171,6 +218,21 @@ func (m *CDense) reuseAsZeroed(r, c int) { m.Zero() } +// isolatedWorkspace returns a new dense matrix w with the size of a and +// returns a callback to defer which performs cleanup at the return of the call. +// This should be used when a method receiver is the same pointer as an input argument. +func (m *CDense) isolatedWorkspace(a CMatrix) (w *CDense, restore func()) { + r, c := a.Dims() + if r == 0 || c == 0 { + panic(ErrZeroLength) + } + w = getWorkspaceCmplx(r, c, false) + return w, func() { + m.Copy(w) + putWorkspaceCmplx(w) + } +} + // Reset zeros the dimensions of the matrix so that it can be reused as the // receiver of a dimensionally restricted operation. // @@ -227,6 +289,14 @@ func (m *CDense) Copy(a CMatrix) (r, c int) { return r, c } +// SetRawCMatrix sets the underlying cblas128.General used by the receiver. +// Changes to elements in the receiver following the call will be reflected +// in b. +func (m *CDense) SetRawCMatrix(b cblas128.General) { + m.capRows, m.capCols = b.Rows, b.Cols + m.mat = b +} + // RawCMatrix returns the underlying cblas128.General used by the receiver. // Changes to elements in the receiver following the call will be reflected // in returned cblas128.General. diff --git a/mat/cdense_test.go b/mat/cdense_test.go index 436550c0..d5440638 100644 --- a/mat/cdense_test.go +++ b/mat/cdense_test.go @@ -4,7 +4,12 @@ package mat -import "testing" +import ( + "math/cmplx" + "testing" + + "golang.org/x/exp/rand" +) func TestCDenseNewAtSet(t *testing.T) { t.Parallel() @@ -48,6 +53,96 @@ func TestCDenseNewAtSet(t *testing.T) { } } +func TestCDenseConjElem(t *testing.T) { + t.Parallel() + + rnd := rand.New(rand.NewSource(1)) + + for r := 1; r <= 8; r++ { + for c := 1; c <= 8; c++ { + const ( + empty = iota + fit + sliced + self + ) + for _, dst := range []int{empty, fit, sliced, self} { + const ( + noTrans = iota + trans + conjTrans + bothHT + bothTH + ) + for _, src := range []int{noTrans, trans, conjTrans, bothHT, bothTH} { + d := NewCDense(r, c, nil) + for i := 0; i < r; i++ { + for j := 0; j < c; j++ { + d.Set(i, j, complex(rnd.NormFloat64(), rnd.NormFloat64())) + } + } + + var ( + a CMatrix + op string + ) + switch src { + case noTrans: + a = d + case trans: + r, c = c, r + a = d.T() + op = ".T" + case conjTrans: + r, c = c, r + a = d.H() + op = ".H" + case bothHT: + a = d.H().T() + op = ".H.T" + case bothTH: + a = d.T().H() + op = ".T.H" + default: + panic("invalid src op") + } + aCopy := NewCDense(r, c, nil) + aCopy.Copy(a) + + var got *CDense + switch dst { + case empty: + got = &CDense{} + case fit: + got = NewCDense(r, c, nil) + case sliced: + got = NewCDense(r*2, c*2, nil).Slice(1, r+1, 1, c+1).(*CDense) + case self: + if r != c && (src == conjTrans || src == trans) { + continue + } + got = d + default: + panic("invalid dst size") + } + + got.Conj(a) + + for i := 0; i < r; i++ { + for j := 0; j < c; j++ { + if got.At(i, j) != cmplx.Conj(aCopy.At(i, j)) { + t.Errorf("unexpected results a%s[%d, %d] for r=%d c=%d %v != %v", + op, i, j, r, c, got.At(i, j), cmplx.Conj(a.At(i, j)), + ) + } + } + } + } + } + } + } +} + func TestCDenseGrow(t *testing.T) { t.Parallel() m := &CDense{} diff --git a/mat/cmatrix.go b/mat/cmatrix.go index c805cfa2..33664575 100644 --- a/mat/cmatrix.go +++ b/mat/cmatrix.go @@ -14,18 +14,24 @@ import ( // CMatrix is the basic matrix interface type for complex matrices. type CMatrix interface { - // Dims returns the dimensions of a Matrix. + // Dims returns the dimensions of a CMatrix. Dims() (r, c int) // At returns the value of a matrix element at row i, column j. // It will panic if i or j are out of bounds for the matrix. At(i, j int) complex128 - // H returns the conjugate transpose of the Matrix. Whether H + // H returns the conjugate transpose of the CMatrix. Whether H // returns a copy of the underlying data is implementation dependent. - // This method may be implemented using the Conjugate type, which + // This method may be implemented using the ConjTranspose type, which // provides an implicit matrix conjugate transpose. H() CMatrix + + // T returns the transpose of the CMatrix. Whether T returns a copy of the + // underlying data is implementation dependent. + // This method may be implemented using the CTranspose type, which + // provides an implicit matrix transpose. + T() CMatrix } // A RawCMatrixer can return a cblas128.General representation of the receiver. Changes to the cblas128.General.Data @@ -35,53 +41,109 @@ type RawCMatrixer interface { } var ( - _ CMatrix = Conjugate{} - _ Unconjugator = Conjugate{} + _ CMatrix = ConjTranspose{} + _ UnConjTransposer = ConjTranspose{} ) -// Conjugate is a type for performing an implicit matrix conjugate transpose. -// It implements the Matrix interface, returning values from the conjugate +// ConjTranspose is a type for performing an implicit matrix conjugate transpose. +// It implements the CMatrix interface, returning values from the conjugate // transpose of the matrix within. -type Conjugate struct { +type ConjTranspose struct { CMatrix CMatrix } // At returns the value of the element at row i and column j of the conjugate -// transposed matrix, that is, row j and column i of the Matrix field. -func (t Conjugate) At(i, j int) complex128 { +// transposed matrix, that is, row j and column i of the CMatrix field. +func (t ConjTranspose) At(i, j int) complex128 { z := t.CMatrix.At(j, i) return cmplx.Conj(z) } // Dims returns the dimensions of the transposed matrix. The number of rows returned -// is the number of columns in the Matrix field, and the number of columns is -// the number of rows in the Matrix field. -func (t Conjugate) Dims() (r, c int) { +// is the number of columns in the CMatrix field, and the number of columns is +// the number of rows in the CMatrix field. +func (t ConjTranspose) Dims() (r, c int) { c, r = t.CMatrix.Dims() return r, c } -// H performs an implicit conjugate transpose by returning the Matrix field. -func (t Conjugate) H() CMatrix { +// H performs an implicit conjugate transpose by returning the CMatrix field. +func (t ConjTranspose) H() CMatrix { return t.CMatrix } -// Unconjugate returns the Matrix field. -func (t Conjugate) Unconjugate() CMatrix { +// T performs an implicit transpose by returning the receiver inside a +// CTranspose. +func (t ConjTranspose) T() CMatrix { + return CTranspose{t} +} + +// UnConjTranspose returns the CMatrix field. +func (t ConjTranspose) UnConjTranspose() CMatrix { return t.CMatrix } -// Unconjugator is a type that can undo an implicit conjugate transpose. -type Unconjugator interface { +// CTranspose is a type for performing an implicit matrix conjugate transpose. +// It implements the CMatrix interface, returning values from the conjugate +// transpose of the matrix within. +type CTranspose struct { + CMatrix CMatrix +} + +// At returns the value of the element at row i and column j of the conjugate +// transposed matrix, that is, row j and column i of the CMatrix field. +func (t CTranspose) At(i, j int) complex128 { + return t.CMatrix.At(j, i) +} + +// Dims returns the dimensions of the transposed matrix. The number of rows returned +// is the number of columns in the CMatrix field, and the number of columns is +// the number of rows in the CMatrix field. +func (t CTranspose) Dims() (r, c int) { + c, r = t.CMatrix.Dims() + return r, c +} + +// H performs an implicit transpose by returning the receiver inside a +// ConjTranspose. +func (t CTranspose) H() CMatrix { + return ConjTranspose{t} +} + +// T performs an implicit conjugate transpose by returning the CMatrix field. +func (t CTranspose) T() CMatrix { + return t.CMatrix +} + +// Untranspose returns the CMatrix field. +func (t CTranspose) Untranspose() CMatrix { + return t.CMatrix +} + +// UnConjTransposer is a type that can undo an implicit conjugate transpose. +type UnConjTransposer interface { + // UnConjTranspose returns the underlying CMatrix stored for the implicit + // conjugate transpose. + UnConjTranspose() CMatrix + // Note: This interface is needed to unify all of the Conjugate types. In - // the cmat128 methods, we need to test if the Matrix has been implicitly + // the cmat128 methods, we need to test if the CMatrix has been implicitly // transposed. If this is checked by testing for the specific Conjugate type // then the behavior will be different if the user uses H() or HTri() for a // triangular matrix. +} - // Unconjugate returns the underlying Matrix stored for the implicit - // conjugate transpose. - Unconjugate() CMatrix +// CUntransposer is a type that can undo an implicit transpose. +type CUntransposer interface { + // Untranspose returns the underlying CMatrix stored for the implicit + // transpose. + Untranspose() CMatrix + + // Note: This interface is needed to unify all of the CTranspose types. In + // the cmat128 methods, we need to test if the CMatrix has been implicitly + // transposed. If this is checked by testing for the specific CTranspose type + // then the behavior will be different if the user uses T() or TTri() for a + // triangular matrix. } // useC returns a complex128 slice with l elements, using c if it @@ -112,14 +174,49 @@ func zeroC(c []complex128) { } } -// unconjugate unconjugates a matrix if applicable. If a is an Unconjugator, then -// unconjugate returns the underlying matrix and true. If it is not, then it returns -// the input matrix and false. -func unconjugate(a CMatrix) (CMatrix, bool) { - if ut, ok := a.(Unconjugator); ok { - return ut.Unconjugate(), true +// untransposeCmplx untransposes a matrix if applicable. If a is an CUntransposer +// or an UnConjTransposer, then untranspose returns the underlying matrix and true for +// the kind of transpose (potentially both). +// If it is not, then it returns the input matrix and false for trans and conj. +func untransposeCmplx(a CMatrix) (u CMatrix, trans, conj bool) { + switch ut := a.(type) { + case CUntransposer: + trans = true + u := ut.Untranspose() + if uc, ok := u.(UnConjTransposer); ok { + return uc.UnConjTranspose(), trans, true + } + return u, trans, false + case UnConjTransposer: + conj = true + u := ut.UnConjTranspose() + if ut, ok := u.(CUntransposer); ok { + return ut.Untranspose(), true, conj + } + return u, false, conj + default: + return a, false, false + } +} + +// untransposeExtractCmplx returns an untransposed matrix in a built-in matrix type. +// +// The untransposed matrix is returned unaltered if it is a built-in matrix type. +// Otherwise, if it implements a Raw method, an appropriate built-in type value +// is returned holding the raw matrix value of the input. If neither of these +// is possible, the untransposed matrix is returned. +func untransposeExtractCmplx(a CMatrix) (u CMatrix, trans, conj bool) { + ut, trans, conj := untransposeCmplx(a) + switch m := ut.(type) { + case *CDense: + return m, trans, conj + case RawCMatrixer: + var d CDense + d.SetRawCMatrix(m.RawCMatrix()) + return &d, trans, conj + default: + return ut, trans, conj } - return a, false } // CEqual returns whether the matrices a and b have the same size diff --git a/mat/pool.go b/mat/pool.go index f51215ce..634382bd 100644 --- a/mat/pool.go +++ b/mat/pool.go @@ -9,6 +9,7 @@ import ( "gonum.org/v1/gonum/blas" "gonum.org/v1/gonum/blas/blas64" + "gonum.org/v1/gonum/blas/cblas128" ) var tab64 = [64]byte{ @@ -59,6 +60,9 @@ var ( // poolInts is the []int equivalent of pool. poolInts [63]sync.Pool + + // poolCmplx is the CDense equivalent of pool. + poolCmplx [63]sync.Pool ) func init() { @@ -94,6 +98,12 @@ func init() { s := make([]int, l) return &s } + + poolCmplx[i].New = func() interface{} { + return &CDense{mat: cblas128.General{ + Data: make([]complex128, l), + }} + } } } @@ -236,3 +246,28 @@ func getInts(l int, clear bool) []int { func putInts(w []int) { poolInts[bits(uint64(cap(w)))].Put(&w) } + +// getWorkspaceCmplx returns a *CDense of size r×c and a data slice +// with a cap that is less than 2*r*c. If clear is true, the +// data slice visible through the CMatrix interface is zeroed. +func getWorkspaceCmplx(r, c int, clear bool) *CDense { + l := uint64(r * c) + w := poolCmplx[bits(l)].Get().(*CDense) + w.mat.Data = w.mat.Data[:l] + if clear { + zeroC(w.mat.Data) + } + w.mat.Rows = r + w.mat.Cols = c + w.mat.Stride = c + w.capRows = r + w.capCols = c + return w +} + +// putWorkspaceCmplx replaces a used *CDense into the appropriate size +// workspace pool. putWorkspace must not be called with a matrix +// where references to the underlying data slice have been kept. +func putWorkspaceCmplx(w *CDense) { + poolCmplx[bits(uint64(cap(w.mat.Data)))].Put(w) +} diff --git a/mat/shadow_complex.go b/mat/shadow_complex.go index 7bf1cefd..1a3f3fc2 100644 --- a/mat/shadow_complex.go +++ b/mat/shadow_complex.go @@ -53,7 +53,7 @@ func checkOverlapComplex(a, b cblas128.General) bool { return false } -func (m *CDense) checkOverlapComplex(a cblas128.General) bool { +func (m *CDense) checkOverlap(a cblas128.General) bool { return checkOverlapComplex(m.RawCMatrix(), a) } @@ -68,5 +68,5 @@ func (m *CDense) checkOverlapMatrix(a CMatrix) bool { case RawCMatrixer: amat = ar.RawCMatrix() } - return m.checkOverlapComplex(amat) + return m.checkOverlap(amat) }