mat: implement helper routines for type extraction and update Trace to use an interface (#932)

* Implement helper routines for type extraction and update Trace to use an interface.

Updates #929.
This commit is contained in:
Brendan Tracey
2019-03-31 09:26:36 +01:00
committed by GitHub
parent a4ad4d254f
commit 9a0642d3dd
7 changed files with 164 additions and 45 deletions

View File

@@ -188,6 +188,13 @@ func (b *BandDense) RawBand() blas64.Band {
return b.mat return b.mat
} }
// SetRawBand sets the underlying blas64.Band used by the receiver.
// Changes to elements in the receiver following the call will be reflected
// in the input.
func (b *BandDense) SetRawBand(mat blas64.Band) {
b.mat = mat
}
// DiagView returns the diagonal as a matrix backed by the original data. // DiagView returns the diagonal as a matrix backed by the original data.
func (b *BandDense) DiagView() Diagonal { func (b *BandDense) DiagView() Diagonal {
n := min(b.mat.Rows, b.mat.Cols) n := min(b.mat.Rows, b.mat.Cols)

View File

@@ -134,16 +134,6 @@ func (m *Dense) Zero() {
} }
} }
// untranspose untransposes a matrix if applicable. If a is an Untransposer, then
// untranspose returns the underlying matrix and true. If it is not, then it returns
// the input matrix and false.
func untranspose(a Matrix) (Matrix, bool) {
if ut, ok := a.(Untransposer); ok {
return ut.Untranspose(), true
}
return a, false
}
// isolatedWorkspace returns a new dense matrix w with the size of a and // isolatedWorkspace returns a new dense matrix w with the size of a and
// returns a callback to defer which performs cleanup at the return of the call. // returns a callback to defer which performs cleanup at the return of the call.
// This should be used when a method receiver is the same pointer as an input argument. // This should be used when a method receiver is the same pointer as an input argument.
@@ -552,3 +542,17 @@ func (m *Dense) Augment(a, b Matrix) {
w := m.Slice(0, br, ac, ac+bc).(*Dense) w := m.Slice(0, br, ac, ac+bc).(*Dense)
w.Copy(b) w.Copy(b)
} }
// Trace returns the trace of the matrix. The matrix must be square or Trace
// will panic.
func (m *Dense) Trace() float64 {
if m.mat.Rows != m.mat.Cols {
panic(ErrSquare)
}
// TODO(btracey): could use internal asm sum routine.
var v float64
for i := 0; i < m.mat.Rows; i++ {
v += m.mat.Data[i*m.mat.Stride+i]
}
return v
}

View File

@@ -207,6 +207,73 @@ type ColNonZeroDoer interface {
DoColNonZero(j int, fn func(i, j int, v float64)) DoColNonZero(j int, fn func(i, j int, v float64))
} }
// untranspose untransposes a matrix if applicable. If a is an Untransposer, then
// untranspose returns the underlying matrix and true. If it is not, then it returns
// the input matrix and false.
func untranspose(a Matrix) (Matrix, bool) {
if ut, ok := a.(Untransposer); ok {
return ut.Untranspose(), true
}
return a, false
}
// untransposeExtract returns an untransposed matrix in a built-in matrix type.
//
// The untransposed matrix is returned unaltered if it is a built-in matrix type.
// Otherwise, if it implements a Raw method, an appropriate built-in type value
// is returned holding the raw matrix value of the input. If neither of these
// is possible, the untransposed matrix is returned.
func untransposeExtract(a Matrix) (Matrix, bool) {
ut, trans := untranspose(a)
switch m := ut.(type) {
case *DiagDense, *SymBandDense, *TriBandDense, *BandDense, *TriDense, *SymDense, *Dense:
return m, trans
// TODO(btracey): Add here if we ever have an equivalent of RawDiagDense.
case RawSymBander:
rsb := m.RawSymBand()
if rsb.Uplo != blas.Upper {
return ut, trans
}
var sb SymBandDense
sb.SetRawSymBand(rsb)
return &sb, trans
case RawTriBander:
rtb := m.RawTriBand()
if rtb.Diag == blas.Unit {
return ut, trans
}
var tb TriBandDense
tb.SetRawTriBand(rtb)
return &tb, trans
case RawBander:
var b BandDense
b.SetRawBand(m.RawBand())
return &b, trans
case RawTriangular:
rt := m.RawTriangular()
if rt.Diag == blas.Unit {
return ut, trans
}
var t TriDense
t.SetRawTriangular(rt)
return &t, trans
case RawSymmetricer:
rs := m.RawSymmetric()
if rs.Uplo != blas.Upper {
return ut, trans
}
var s SymDense
s.SetRawSymmetric(rs)
return &s, trans
case RawMatrixer:
var d Dense
d.SetRawMatrix(m.RawMatrix())
return &d, trans
default:
return ut, trans
}
}
// TODO(btracey): Consider adding CopyCol/CopyRow if the behavior seems useful. // TODO(btracey): Consider adding CopyCol/CopyRow if the behavior seems useful.
// TODO(btracey): Add in fast paths to Row/Col for the other concrete types // TODO(btracey): Add in fast paths to Row/Col for the other concrete types
// (TriDense, etc.) as well as relevant interfaces (RowColer, RawRowViewer, etc.) // (TriDense, etc.) as well as relevant interfaces (RowColer, RawRowViewer, etc.)
@@ -803,44 +870,28 @@ func Sum(a Matrix) float64 {
return sum return sum
} }
// A Tracer can compute the trace of the matrix. Trace must panic if the
// matrix is not square.
type Tracer interface {
Trace() float64
}
// Trace returns the trace of the matrix. Trace will panic if the // Trace returns the trace of the matrix. Trace will panic if the
// matrix is not square. // matrix is not square.
func Trace(a Matrix) float64 { func Trace(a Matrix) float64 {
m, _ := untransposeExtract(a)
if t, ok := m.(Tracer); ok {
return t.Trace()
}
r, c := a.Dims() r, c := a.Dims()
if r != c { if r != c {
panic(ErrSquare) panic(ErrSquare)
} }
var v float64
aU, _ := untranspose(a) for i := 0; i < r; i++ {
switch m := aU.(type) { v += a.At(i, i)
case RawMatrixer:
rm := m.RawMatrix()
var t float64
for i := 0; i < r; i++ {
t += rm.Data[i*rm.Stride+i]
}
return t
case RawTriangular:
rm := m.RawTriangular()
var t float64
for i := 0; i < r; i++ {
t += rm.Data[i*rm.Stride+i]
}
return t
case RawSymmetricer:
rm := m.RawSymmetric()
var t float64
for i := 0; i < r; i++ {
t += rm.Data[i*rm.Stride+i]
}
return t
default:
var t float64
for i := 0; i < r; i++ {
t += a.At(i, i)
}
return t
} }
return v
} }
func min(a, b int) int { func min(a, b int) int {

View File

@@ -145,6 +145,18 @@ func (s *SymBandDense) RawSymBand() blas64.SymmetricBand {
return s.mat return s.mat
} }
// SetRawSymBand sets the underlying blas64.SymmetricBand used by the receiver.
// Changes to elements in the receiver following the call will be reflected
// in the input.
//
// The supplied SymmetricBand must use blas.Upper storage format.
func (s *SymBandDense) SetRawSymBand(mat blas64.SymmetricBand) {
if mat.Uplo != blas.Upper {
panic("mat: blas64.SymmetricBand does not have blas.Upper storage")
}
s.mat = mat
}
// Zero sets all of the matrix elements to zero. // Zero sets all of the matrix elements to zero.
func (s *SymBandDense) Zero() { func (s *SymBandDense) Zero() {
for i := 0; i < s.mat.N; i++ { for i := 0; i < s.mat.N; i++ {

View File

@@ -113,13 +113,14 @@ func (s *SymDense) RawSymmetric() blas64.Symmetric {
// SetRawSymmetric sets the underlying blas64.Symmetric used by the receiver. // SetRawSymmetric sets the underlying blas64.Symmetric used by the receiver.
// Changes to elements in the receiver following the call will be reflected // Changes to elements in the receiver following the call will be reflected
// in b. SetRawSymmetric will panic if b is not an upper-encoded symmetric // in the input.
// matrix. //
func (s *SymDense) SetRawSymmetric(b blas64.Symmetric) { // The supplied Symmetric must use blas.Upper storage format.
if b.Uplo != blas.Upper { func (s *SymDense) SetRawSymmetric(mat blas64.Symmetric) {
if mat.Uplo != blas.Upper {
panic(badSymTriangle) panic(badSymTriangle)
} }
s.mat = b s.mat = mat
} }
// Reset zeros the dimensions of the matrix so that it can be reused as the // Reset zeros the dimensions of the matrix so that it can be reused as the
@@ -514,6 +515,16 @@ func (s *SymDense) SliceSym(i, k int) Symmetric {
return &v return &v
} }
// Trace returns the trace of the matrix.
func (s *SymDense) Trace() float64 {
// TODO(btracey): could use internal asm sum routine.
var v float64
for i := 0; i < s.mat.N; i++ {
v += s.mat.Data[i*s.mat.Stride+i]
}
return v
}
// GrowSym returns the receiver expanded by n rows and n columns. If the // GrowSym returns the receiver expanded by n rows and n columns. If the
// dimensions of the expanded matrix are outside the capacity of the receiver // dimensions of the expanded matrix are outside the capacity of the receiver
// a new allocation is made, otherwise not. Note that the receiver itself is // a new allocation is made, otherwise not. Note that the receiver itself is

View File

@@ -217,6 +217,18 @@ func (t *TriDense) RawTriangular() blas64.Triangular {
return t.mat return t.mat
} }
// SetRawTriangular sets the underlying blas64.Triangular used by the receiver.
// Changes to elements in the receiver following the call will be reflected
// in the input.
//
// The supplied Triangular must not use blas.Unit storage format.
func (t *TriDense) SetRawTriangular(mat blas64.Triangular) {
if mat.Diag == blas.Unit {
panic("mat: cannot set TriDense with Unit storage format")
}
t.mat = mat
}
// Reset zeros the dimensions of the matrix so that it can be reused as the // Reset zeros the dimensions of the matrix so that it can be reused as the
// receiver of a dimensionally restricted operation. // receiver of a dimensionally restricted operation.
// //
@@ -513,6 +525,16 @@ func (t *TriDense) ScaleTri(f float64, a Triangular) {
} }
} }
// Trace returns the trace of the matrix.
func (t *TriDense) Trace() float64 {
// TODO(btracey): could use internal asm sum routine.
var v float64
for i := 0; i < t.mat.N; i++ {
v += t.mat.Data[i*t.mat.Stride+i]
}
return v
}
// copySymIntoTriangle copies a symmetric matrix into a TriDense // copySymIntoTriangle copies a symmetric matrix into a TriDense
func copySymIntoTriangle(t *TriDense, s Symmetric) { func copySymIntoTriangle(t *TriDense, s Symmetric) {
n, upper := t.Triangle() n, upper := t.Triangle()

View File

@@ -321,6 +321,18 @@ func (t *TriBandDense) RawTriBand() blas64.TriangularBand {
return t.mat return t.mat
} }
// SetRawTriBand sets the underlying blas64.TriangularBand used by the receiver.
// Changes to elements in the receiver following the call will be reflected
// in the input.
//
// The supplied TriangularBand must not use blas.Unit storage format.
func (t *TriBandDense) SetRawTriBand(mat blas64.TriangularBand) {
if mat.Diag == blas.Unit {
panic("mat: cannot set TriBand with Unit storage")
}
t.mat = mat
}
// DiagView returns the diagonal as a matrix backed by the original data. // DiagView returns the diagonal as a matrix backed by the original data.
func (t *TriBandDense) DiagView() Diagonal { func (t *TriBandDense) DiagView() Diagonal {
if t.mat.Diag == blas.Unit { if t.mat.Diag == blas.Unit {