mirror of
https://github.com/gonum/gonum.git
synced 2025-10-07 08:01:20 +08:00
mat: implement helper routines for type extraction and update Trace to use an interface (#932)
* Implement helper routines for type extraction and update Trace to use an interface. Updates #929.
This commit is contained in:
@@ -188,6 +188,13 @@ func (b *BandDense) RawBand() blas64.Band {
|
||||
return b.mat
|
||||
}
|
||||
|
||||
// SetRawBand sets the underlying blas64.Band used by the receiver.
|
||||
// Changes to elements in the receiver following the call will be reflected
|
||||
// in the input.
|
||||
func (b *BandDense) SetRawBand(mat blas64.Band) {
|
||||
b.mat = mat
|
||||
}
|
||||
|
||||
// DiagView returns the diagonal as a matrix backed by the original data.
|
||||
func (b *BandDense) DiagView() Diagonal {
|
||||
n := min(b.mat.Rows, b.mat.Cols)
|
||||
|
24
mat/dense.go
24
mat/dense.go
@@ -134,16 +134,6 @@ func (m *Dense) Zero() {
|
||||
}
|
||||
}
|
||||
|
||||
// untranspose untransposes a matrix if applicable. If a is an Untransposer, then
|
||||
// untranspose returns the underlying matrix and true. If it is not, then it returns
|
||||
// the input matrix and false.
|
||||
func untranspose(a Matrix) (Matrix, bool) {
|
||||
if ut, ok := a.(Untransposer); ok {
|
||||
return ut.Untranspose(), true
|
||||
}
|
||||
return a, false
|
||||
}
|
||||
|
||||
// isolatedWorkspace returns a new dense matrix w with the size of a and
|
||||
// returns a callback to defer which performs cleanup at the return of the call.
|
||||
// This should be used when a method receiver is the same pointer as an input argument.
|
||||
@@ -552,3 +542,17 @@ func (m *Dense) Augment(a, b Matrix) {
|
||||
w := m.Slice(0, br, ac, ac+bc).(*Dense)
|
||||
w.Copy(b)
|
||||
}
|
||||
|
||||
// Trace returns the trace of the matrix. The matrix must be square or Trace
|
||||
// will panic.
|
||||
func (m *Dense) Trace() float64 {
|
||||
if m.mat.Rows != m.mat.Cols {
|
||||
panic(ErrSquare)
|
||||
}
|
||||
// TODO(btracey): could use internal asm sum routine.
|
||||
var v float64
|
||||
for i := 0; i < m.mat.Rows; i++ {
|
||||
v += m.mat.Data[i*m.mat.Stride+i]
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
109
mat/matrix.go
109
mat/matrix.go
@@ -207,6 +207,73 @@ type ColNonZeroDoer interface {
|
||||
DoColNonZero(j int, fn func(i, j int, v float64))
|
||||
}
|
||||
|
||||
// untranspose untransposes a matrix if applicable. If a is an Untransposer, then
|
||||
// untranspose returns the underlying matrix and true. If it is not, then it returns
|
||||
// the input matrix and false.
|
||||
func untranspose(a Matrix) (Matrix, bool) {
|
||||
if ut, ok := a.(Untransposer); ok {
|
||||
return ut.Untranspose(), true
|
||||
}
|
||||
return a, false
|
||||
}
|
||||
|
||||
// untransposeExtract returns an untransposed matrix in a built-in matrix type.
|
||||
//
|
||||
// The untransposed matrix is returned unaltered if it is a built-in matrix type.
|
||||
// Otherwise, if it implements a Raw method, an appropriate built-in type value
|
||||
// is returned holding the raw matrix value of the input. If neither of these
|
||||
// is possible, the untransposed matrix is returned.
|
||||
func untransposeExtract(a Matrix) (Matrix, bool) {
|
||||
ut, trans := untranspose(a)
|
||||
switch m := ut.(type) {
|
||||
case *DiagDense, *SymBandDense, *TriBandDense, *BandDense, *TriDense, *SymDense, *Dense:
|
||||
return m, trans
|
||||
// TODO(btracey): Add here if we ever have an equivalent of RawDiagDense.
|
||||
case RawSymBander:
|
||||
rsb := m.RawSymBand()
|
||||
if rsb.Uplo != blas.Upper {
|
||||
return ut, trans
|
||||
}
|
||||
var sb SymBandDense
|
||||
sb.SetRawSymBand(rsb)
|
||||
return &sb, trans
|
||||
case RawTriBander:
|
||||
rtb := m.RawTriBand()
|
||||
if rtb.Diag == blas.Unit {
|
||||
return ut, trans
|
||||
}
|
||||
var tb TriBandDense
|
||||
tb.SetRawTriBand(rtb)
|
||||
return &tb, trans
|
||||
case RawBander:
|
||||
var b BandDense
|
||||
b.SetRawBand(m.RawBand())
|
||||
return &b, trans
|
||||
case RawTriangular:
|
||||
rt := m.RawTriangular()
|
||||
if rt.Diag == blas.Unit {
|
||||
return ut, trans
|
||||
}
|
||||
var t TriDense
|
||||
t.SetRawTriangular(rt)
|
||||
return &t, trans
|
||||
case RawSymmetricer:
|
||||
rs := m.RawSymmetric()
|
||||
if rs.Uplo != blas.Upper {
|
||||
return ut, trans
|
||||
}
|
||||
var s SymDense
|
||||
s.SetRawSymmetric(rs)
|
||||
return &s, trans
|
||||
case RawMatrixer:
|
||||
var d Dense
|
||||
d.SetRawMatrix(m.RawMatrix())
|
||||
return &d, trans
|
||||
default:
|
||||
return ut, trans
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(btracey): Consider adding CopyCol/CopyRow if the behavior seems useful.
|
||||
// TODO(btracey): Add in fast paths to Row/Col for the other concrete types
|
||||
// (TriDense, etc.) as well as relevant interfaces (RowColer, RawRowViewer, etc.)
|
||||
@@ -803,44 +870,28 @@ func Sum(a Matrix) float64 {
|
||||
return sum
|
||||
}
|
||||
|
||||
// A Tracer can compute the trace of the matrix. Trace must panic if the
|
||||
// matrix is not square.
|
||||
type Tracer interface {
|
||||
Trace() float64
|
||||
}
|
||||
|
||||
// Trace returns the trace of the matrix. Trace will panic if the
|
||||
// matrix is not square.
|
||||
func Trace(a Matrix) float64 {
|
||||
m, _ := untransposeExtract(a)
|
||||
if t, ok := m.(Tracer); ok {
|
||||
return t.Trace()
|
||||
}
|
||||
r, c := a.Dims()
|
||||
if r != c {
|
||||
panic(ErrSquare)
|
||||
}
|
||||
|
||||
aU, _ := untranspose(a)
|
||||
switch m := aU.(type) {
|
||||
case RawMatrixer:
|
||||
rm := m.RawMatrix()
|
||||
var t float64
|
||||
var v float64
|
||||
for i := 0; i < r; i++ {
|
||||
t += rm.Data[i*rm.Stride+i]
|
||||
}
|
||||
return t
|
||||
case RawTriangular:
|
||||
rm := m.RawTriangular()
|
||||
var t float64
|
||||
for i := 0; i < r; i++ {
|
||||
t += rm.Data[i*rm.Stride+i]
|
||||
}
|
||||
return t
|
||||
case RawSymmetricer:
|
||||
rm := m.RawSymmetric()
|
||||
var t float64
|
||||
for i := 0; i < r; i++ {
|
||||
t += rm.Data[i*rm.Stride+i]
|
||||
}
|
||||
return t
|
||||
default:
|
||||
var t float64
|
||||
for i := 0; i < r; i++ {
|
||||
t += a.At(i, i)
|
||||
}
|
||||
return t
|
||||
v += a.At(i, i)
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func min(a, b int) int {
|
||||
|
@@ -145,6 +145,18 @@ func (s *SymBandDense) RawSymBand() blas64.SymmetricBand {
|
||||
return s.mat
|
||||
}
|
||||
|
||||
// SetRawSymBand sets the underlying blas64.SymmetricBand used by the receiver.
|
||||
// Changes to elements in the receiver following the call will be reflected
|
||||
// in the input.
|
||||
//
|
||||
// The supplied SymmetricBand must use blas.Upper storage format.
|
||||
func (s *SymBandDense) SetRawSymBand(mat blas64.SymmetricBand) {
|
||||
if mat.Uplo != blas.Upper {
|
||||
panic("mat: blas64.SymmetricBand does not have blas.Upper storage")
|
||||
}
|
||||
s.mat = mat
|
||||
}
|
||||
|
||||
// Zero sets all of the matrix elements to zero.
|
||||
func (s *SymBandDense) Zero() {
|
||||
for i := 0; i < s.mat.N; i++ {
|
||||
|
@@ -113,13 +113,14 @@ func (s *SymDense) RawSymmetric() blas64.Symmetric {
|
||||
|
||||
// SetRawSymmetric sets the underlying blas64.Symmetric used by the receiver.
|
||||
// Changes to elements in the receiver following the call will be reflected
|
||||
// in b. SetRawSymmetric will panic if b is not an upper-encoded symmetric
|
||||
// matrix.
|
||||
func (s *SymDense) SetRawSymmetric(b blas64.Symmetric) {
|
||||
if b.Uplo != blas.Upper {
|
||||
// in the input.
|
||||
//
|
||||
// The supplied Symmetric must use blas.Upper storage format.
|
||||
func (s *SymDense) SetRawSymmetric(mat blas64.Symmetric) {
|
||||
if mat.Uplo != blas.Upper {
|
||||
panic(badSymTriangle)
|
||||
}
|
||||
s.mat = b
|
||||
s.mat = mat
|
||||
}
|
||||
|
||||
// Reset zeros the dimensions of the matrix so that it can be reused as the
|
||||
@@ -514,6 +515,16 @@ func (s *SymDense) SliceSym(i, k int) Symmetric {
|
||||
return &v
|
||||
}
|
||||
|
||||
// Trace returns the trace of the matrix.
|
||||
func (s *SymDense) Trace() float64 {
|
||||
// TODO(btracey): could use internal asm sum routine.
|
||||
var v float64
|
||||
for i := 0; i < s.mat.N; i++ {
|
||||
v += s.mat.Data[i*s.mat.Stride+i]
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// GrowSym returns the receiver expanded by n rows and n columns. If the
|
||||
// dimensions of the expanded matrix are outside the capacity of the receiver
|
||||
// a new allocation is made, otherwise not. Note that the receiver itself is
|
||||
|
@@ -217,6 +217,18 @@ func (t *TriDense) RawTriangular() blas64.Triangular {
|
||||
return t.mat
|
||||
}
|
||||
|
||||
// SetRawTriangular sets the underlying blas64.Triangular used by the receiver.
|
||||
// Changes to elements in the receiver following the call will be reflected
|
||||
// in the input.
|
||||
//
|
||||
// The supplied Triangular must not use blas.Unit storage format.
|
||||
func (t *TriDense) SetRawTriangular(mat blas64.Triangular) {
|
||||
if mat.Diag == blas.Unit {
|
||||
panic("mat: cannot set TriDense with Unit storage format")
|
||||
}
|
||||
t.mat = mat
|
||||
}
|
||||
|
||||
// Reset zeros the dimensions of the matrix so that it can be reused as the
|
||||
// receiver of a dimensionally restricted operation.
|
||||
//
|
||||
@@ -513,6 +525,16 @@ func (t *TriDense) ScaleTri(f float64, a Triangular) {
|
||||
}
|
||||
}
|
||||
|
||||
// Trace returns the trace of the matrix.
|
||||
func (t *TriDense) Trace() float64 {
|
||||
// TODO(btracey): could use internal asm sum routine.
|
||||
var v float64
|
||||
for i := 0; i < t.mat.N; i++ {
|
||||
v += t.mat.Data[i*t.mat.Stride+i]
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// copySymIntoTriangle copies a symmetric matrix into a TriDense
|
||||
func copySymIntoTriangle(t *TriDense, s Symmetric) {
|
||||
n, upper := t.Triangle()
|
||||
|
@@ -321,6 +321,18 @@ func (t *TriBandDense) RawTriBand() blas64.TriangularBand {
|
||||
return t.mat
|
||||
}
|
||||
|
||||
// SetRawTriBand sets the underlying blas64.TriangularBand used by the receiver.
|
||||
// Changes to elements in the receiver following the call will be reflected
|
||||
// in the input.
|
||||
//
|
||||
// The supplied TriangularBand must not use blas.Unit storage format.
|
||||
func (t *TriBandDense) SetRawTriBand(mat blas64.TriangularBand) {
|
||||
if mat.Diag == blas.Unit {
|
||||
panic("mat: cannot set TriBand with Unit storage")
|
||||
}
|
||||
t.mat = mat
|
||||
}
|
||||
|
||||
// DiagView returns the diagonal as a matrix backed by the original data.
|
||||
func (t *TriBandDense) DiagView() Diagonal {
|
||||
if t.mat.Diag == blas.Unit {
|
||||
|
Reference in New Issue
Block a user