diff --git a/mat/band.go b/mat/band.go index 20225109..72ebefdd 100644 --- a/mat/band.go +++ b/mat/band.go @@ -188,6 +188,13 @@ func (b *BandDense) RawBand() blas64.Band { return b.mat } +// SetRawBand sets the underlying blas64.Band used by the receiver. +// Changes to elements in the receiver following the call will be reflected +// in the input. +func (b *BandDense) SetRawBand(mat blas64.Band) { + b.mat = mat +} + // DiagView returns the diagonal as a matrix backed by the original data. func (b *BandDense) DiagView() Diagonal { n := min(b.mat.Rows, b.mat.Cols) diff --git a/mat/dense.go b/mat/dense.go index be4ba8bb..87b1105c 100644 --- a/mat/dense.go +++ b/mat/dense.go @@ -134,16 +134,6 @@ func (m *Dense) Zero() { } } -// untranspose untransposes a matrix if applicable. If a is an Untransposer, then -// untranspose returns the underlying matrix and true. If it is not, then it returns -// the input matrix and false. -func untranspose(a Matrix) (Matrix, bool) { - if ut, ok := a.(Untransposer); ok { - return ut.Untranspose(), true - } - return a, false -} - // isolatedWorkspace returns a new dense matrix w with the size of a and // returns a callback to defer which performs cleanup at the return of the call. // This should be used when a method receiver is the same pointer as an input argument. @@ -552,3 +542,17 @@ func (m *Dense) Augment(a, b Matrix) { w := m.Slice(0, br, ac, ac+bc).(*Dense) w.Copy(b) } + +// Trace returns the trace of the matrix. The matrix must be square or Trace +// will panic. +func (m *Dense) Trace() float64 { + if m.mat.Rows != m.mat.Cols { + panic(ErrSquare) + } + // TODO(btracey): could use internal asm sum routine. + var v float64 + for i := 0; i < m.mat.Rows; i++ { + v += m.mat.Data[i*m.mat.Stride+i] + } + return v +} diff --git a/mat/matrix.go b/mat/matrix.go index df217d9d..444d0445 100644 --- a/mat/matrix.go +++ b/mat/matrix.go @@ -207,6 +207,73 @@ type ColNonZeroDoer interface { DoColNonZero(j int, fn func(i, j int, v float64)) } +// untranspose untransposes a matrix if applicable. If a is an Untransposer, then +// untranspose returns the underlying matrix and true. If it is not, then it returns +// the input matrix and false. +func untranspose(a Matrix) (Matrix, bool) { + if ut, ok := a.(Untransposer); ok { + return ut.Untranspose(), true + } + return a, false +} + +// untransposeExtract returns an untransposed matrix in a built-in matrix type. +// +// The untransposed matrix is returned unaltered if it is a built-in matrix type. +// Otherwise, if it implements a Raw method, an appropriate built-in type value +// is returned holding the raw matrix value of the input. If neither of these +// is possible, the untransposed matrix is returned. +func untransposeExtract(a Matrix) (Matrix, bool) { + ut, trans := untranspose(a) + switch m := ut.(type) { + case *DiagDense, *SymBandDense, *TriBandDense, *BandDense, *TriDense, *SymDense, *Dense: + return m, trans + // TODO(btracey): Add here if we ever have an equivalent of RawDiagDense. + case RawSymBander: + rsb := m.RawSymBand() + if rsb.Uplo != blas.Upper { + return ut, trans + } + var sb SymBandDense + sb.SetRawSymBand(rsb) + return &sb, trans + case RawTriBander: + rtb := m.RawTriBand() + if rtb.Diag == blas.Unit { + return ut, trans + } + var tb TriBandDense + tb.SetRawTriBand(rtb) + return &tb, trans + case RawBander: + var b BandDense + b.SetRawBand(m.RawBand()) + return &b, trans + case RawTriangular: + rt := m.RawTriangular() + if rt.Diag == blas.Unit { + return ut, trans + } + var t TriDense + t.SetRawTriangular(rt) + return &t, trans + case RawSymmetricer: + rs := m.RawSymmetric() + if rs.Uplo != blas.Upper { + return ut, trans + } + var s SymDense + s.SetRawSymmetric(rs) + return &s, trans + case RawMatrixer: + var d Dense + d.SetRawMatrix(m.RawMatrix()) + return &d, trans + default: + return ut, trans + } +} + // TODO(btracey): Consider adding CopyCol/CopyRow if the behavior seems useful. // TODO(btracey): Add in fast paths to Row/Col for the other concrete types // (TriDense, etc.) as well as relevant interfaces (RowColer, RawRowViewer, etc.) @@ -803,44 +870,28 @@ func Sum(a Matrix) float64 { return sum } +// A Tracer can compute the trace of the matrix. Trace must panic if the +// matrix is not square. +type Tracer interface { + Trace() float64 +} + // Trace returns the trace of the matrix. Trace will panic if the // matrix is not square. func Trace(a Matrix) float64 { + m, _ := untransposeExtract(a) + if t, ok := m.(Tracer); ok { + return t.Trace() + } r, c := a.Dims() if r != c { panic(ErrSquare) } - - aU, _ := untranspose(a) - switch m := aU.(type) { - case RawMatrixer: - rm := m.RawMatrix() - var t float64 - for i := 0; i < r; i++ { - t += rm.Data[i*rm.Stride+i] - } - return t - case RawTriangular: - rm := m.RawTriangular() - var t float64 - for i := 0; i < r; i++ { - t += rm.Data[i*rm.Stride+i] - } - return t - case RawSymmetricer: - rm := m.RawSymmetric() - var t float64 - for i := 0; i < r; i++ { - t += rm.Data[i*rm.Stride+i] - } - return t - default: - var t float64 - for i := 0; i < r; i++ { - t += a.At(i, i) - } - return t + var v float64 + for i := 0; i < r; i++ { + v += a.At(i, i) } + return v } func min(a, b int) int { diff --git a/mat/symband.go b/mat/symband.go index a6b71030..add9a807 100644 --- a/mat/symband.go +++ b/mat/symband.go @@ -145,6 +145,18 @@ func (s *SymBandDense) RawSymBand() blas64.SymmetricBand { return s.mat } +// SetRawSymBand sets the underlying blas64.SymmetricBand used by the receiver. +// Changes to elements in the receiver following the call will be reflected +// in the input. +// +// The supplied SymmetricBand must use blas.Upper storage format. +func (s *SymBandDense) SetRawSymBand(mat blas64.SymmetricBand) { + if mat.Uplo != blas.Upper { + panic("mat: blas64.SymmetricBand does not have blas.Upper storage") + } + s.mat = mat +} + // Zero sets all of the matrix elements to zero. func (s *SymBandDense) Zero() { for i := 0; i < s.mat.N; i++ { diff --git a/mat/symmetric.go b/mat/symmetric.go index c859b534..2ea5bdb0 100644 --- a/mat/symmetric.go +++ b/mat/symmetric.go @@ -113,13 +113,14 @@ func (s *SymDense) RawSymmetric() blas64.Symmetric { // SetRawSymmetric sets the underlying blas64.Symmetric used by the receiver. // Changes to elements in the receiver following the call will be reflected -// in b. SetRawSymmetric will panic if b is not an upper-encoded symmetric -// matrix. -func (s *SymDense) SetRawSymmetric(b blas64.Symmetric) { - if b.Uplo != blas.Upper { +// in the input. +// +// The supplied Symmetric must use blas.Upper storage format. +func (s *SymDense) SetRawSymmetric(mat blas64.Symmetric) { + if mat.Uplo != blas.Upper { panic(badSymTriangle) } - s.mat = b + s.mat = mat } // Reset zeros the dimensions of the matrix so that it can be reused as the @@ -514,6 +515,16 @@ func (s *SymDense) SliceSym(i, k int) Symmetric { return &v } +// Trace returns the trace of the matrix. +func (s *SymDense) Trace() float64 { + // TODO(btracey): could use internal asm sum routine. + var v float64 + for i := 0; i < s.mat.N; i++ { + v += s.mat.Data[i*s.mat.Stride+i] + } + return v +} + // GrowSym returns the receiver expanded by n rows and n columns. If the // dimensions of the expanded matrix are outside the capacity of the receiver // a new allocation is made, otherwise not. Note that the receiver itself is diff --git a/mat/triangular.go b/mat/triangular.go index a98cf3a3..e32ee405 100644 --- a/mat/triangular.go +++ b/mat/triangular.go @@ -217,6 +217,18 @@ func (t *TriDense) RawTriangular() blas64.Triangular { return t.mat } +// SetRawTriangular sets the underlying blas64.Triangular used by the receiver. +// Changes to elements in the receiver following the call will be reflected +// in the input. +// +// The supplied Triangular must not use blas.Unit storage format. +func (t *TriDense) SetRawTriangular(mat blas64.Triangular) { + if mat.Diag == blas.Unit { + panic("mat: cannot set TriDense with Unit storage format") + } + t.mat = mat +} + // Reset zeros the dimensions of the matrix so that it can be reused as the // receiver of a dimensionally restricted operation. // @@ -513,6 +525,16 @@ func (t *TriDense) ScaleTri(f float64, a Triangular) { } } +// Trace returns the trace of the matrix. +func (t *TriDense) Trace() float64 { + // TODO(btracey): could use internal asm sum routine. + var v float64 + for i := 0; i < t.mat.N; i++ { + v += t.mat.Data[i*t.mat.Stride+i] + } + return v +} + // copySymIntoTriangle copies a symmetric matrix into a TriDense func copySymIntoTriangle(t *TriDense, s Symmetric) { n, upper := t.Triangle() diff --git a/mat/triband.go b/mat/triband.go index 71ad38ff..f9785504 100644 --- a/mat/triband.go +++ b/mat/triband.go @@ -321,6 +321,18 @@ func (t *TriBandDense) RawTriBand() blas64.TriangularBand { return t.mat } +// SetRawTriBand sets the underlying blas64.TriangularBand used by the receiver. +// Changes to elements in the receiver following the call will be reflected +// in the input. +// +// The supplied TriangularBand must not use blas.Unit storage format. +func (t *TriBandDense) SetRawTriBand(mat blas64.TriangularBand) { + if mat.Diag == blas.Unit { + panic("mat: cannot set TriBand with Unit storage") + } + t.mat = mat +} + // DiagView returns the diagonal as a matrix backed by the original data. func (t *TriBandDense) DiagView() Diagonal { if t.mat.Diag == blas.Unit {