mirror of
https://github.com/gonum/gonum.git
synced 2025-10-06 23:52:47 +08:00
mat: implement helper routines for type extraction and update Trace to use an interface (#932)
* Implement helper routines for type extraction and update Trace to use an interface. Updates #929.
This commit is contained in:
@@ -188,6 +188,13 @@ func (b *BandDense) RawBand() blas64.Band {
|
|||||||
return b.mat
|
return b.mat
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetRawBand sets the underlying blas64.Band used by the receiver.
|
||||||
|
// Changes to elements in the receiver following the call will be reflected
|
||||||
|
// in the input.
|
||||||
|
func (b *BandDense) SetRawBand(mat blas64.Band) {
|
||||||
|
b.mat = mat
|
||||||
|
}
|
||||||
|
|
||||||
// DiagView returns the diagonal as a matrix backed by the original data.
|
// DiagView returns the diagonal as a matrix backed by the original data.
|
||||||
func (b *BandDense) DiagView() Diagonal {
|
func (b *BandDense) DiagView() Diagonal {
|
||||||
n := min(b.mat.Rows, b.mat.Cols)
|
n := min(b.mat.Rows, b.mat.Cols)
|
||||||
|
24
mat/dense.go
24
mat/dense.go
@@ -134,16 +134,6 @@ func (m *Dense) Zero() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// untranspose untransposes a matrix if applicable. If a is an Untransposer, then
|
|
||||||
// untranspose returns the underlying matrix and true. If it is not, then it returns
|
|
||||||
// the input matrix and false.
|
|
||||||
func untranspose(a Matrix) (Matrix, bool) {
|
|
||||||
if ut, ok := a.(Untransposer); ok {
|
|
||||||
return ut.Untranspose(), true
|
|
||||||
}
|
|
||||||
return a, false
|
|
||||||
}
|
|
||||||
|
|
||||||
// isolatedWorkspace returns a new dense matrix w with the size of a and
|
// isolatedWorkspace returns a new dense matrix w with the size of a and
|
||||||
// returns a callback to defer which performs cleanup at the return of the call.
|
// returns a callback to defer which performs cleanup at the return of the call.
|
||||||
// This should be used when a method receiver is the same pointer as an input argument.
|
// This should be used when a method receiver is the same pointer as an input argument.
|
||||||
@@ -552,3 +542,17 @@ func (m *Dense) Augment(a, b Matrix) {
|
|||||||
w := m.Slice(0, br, ac, ac+bc).(*Dense)
|
w := m.Slice(0, br, ac, ac+bc).(*Dense)
|
||||||
w.Copy(b)
|
w.Copy(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Trace returns the trace of the matrix. The matrix must be square or Trace
|
||||||
|
// will panic.
|
||||||
|
func (m *Dense) Trace() float64 {
|
||||||
|
if m.mat.Rows != m.mat.Cols {
|
||||||
|
panic(ErrSquare)
|
||||||
|
}
|
||||||
|
// TODO(btracey): could use internal asm sum routine.
|
||||||
|
var v float64
|
||||||
|
for i := 0; i < m.mat.Rows; i++ {
|
||||||
|
v += m.mat.Data[i*m.mat.Stride+i]
|
||||||
|
}
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
111
mat/matrix.go
111
mat/matrix.go
@@ -207,6 +207,73 @@ type ColNonZeroDoer interface {
|
|||||||
DoColNonZero(j int, fn func(i, j int, v float64))
|
DoColNonZero(j int, fn func(i, j int, v float64))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// untranspose untransposes a matrix if applicable. If a is an Untransposer, then
|
||||||
|
// untranspose returns the underlying matrix and true. If it is not, then it returns
|
||||||
|
// the input matrix and false.
|
||||||
|
func untranspose(a Matrix) (Matrix, bool) {
|
||||||
|
if ut, ok := a.(Untransposer); ok {
|
||||||
|
return ut.Untranspose(), true
|
||||||
|
}
|
||||||
|
return a, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// untransposeExtract returns an untransposed matrix in a built-in matrix type.
|
||||||
|
//
|
||||||
|
// The untransposed matrix is returned unaltered if it is a built-in matrix type.
|
||||||
|
// Otherwise, if it implements a Raw method, an appropriate built-in type value
|
||||||
|
// is returned holding the raw matrix value of the input. If neither of these
|
||||||
|
// is possible, the untransposed matrix is returned.
|
||||||
|
func untransposeExtract(a Matrix) (Matrix, bool) {
|
||||||
|
ut, trans := untranspose(a)
|
||||||
|
switch m := ut.(type) {
|
||||||
|
case *DiagDense, *SymBandDense, *TriBandDense, *BandDense, *TriDense, *SymDense, *Dense:
|
||||||
|
return m, trans
|
||||||
|
// TODO(btracey): Add here if we ever have an equivalent of RawDiagDense.
|
||||||
|
case RawSymBander:
|
||||||
|
rsb := m.RawSymBand()
|
||||||
|
if rsb.Uplo != blas.Upper {
|
||||||
|
return ut, trans
|
||||||
|
}
|
||||||
|
var sb SymBandDense
|
||||||
|
sb.SetRawSymBand(rsb)
|
||||||
|
return &sb, trans
|
||||||
|
case RawTriBander:
|
||||||
|
rtb := m.RawTriBand()
|
||||||
|
if rtb.Diag == blas.Unit {
|
||||||
|
return ut, trans
|
||||||
|
}
|
||||||
|
var tb TriBandDense
|
||||||
|
tb.SetRawTriBand(rtb)
|
||||||
|
return &tb, trans
|
||||||
|
case RawBander:
|
||||||
|
var b BandDense
|
||||||
|
b.SetRawBand(m.RawBand())
|
||||||
|
return &b, trans
|
||||||
|
case RawTriangular:
|
||||||
|
rt := m.RawTriangular()
|
||||||
|
if rt.Diag == blas.Unit {
|
||||||
|
return ut, trans
|
||||||
|
}
|
||||||
|
var t TriDense
|
||||||
|
t.SetRawTriangular(rt)
|
||||||
|
return &t, trans
|
||||||
|
case RawSymmetricer:
|
||||||
|
rs := m.RawSymmetric()
|
||||||
|
if rs.Uplo != blas.Upper {
|
||||||
|
return ut, trans
|
||||||
|
}
|
||||||
|
var s SymDense
|
||||||
|
s.SetRawSymmetric(rs)
|
||||||
|
return &s, trans
|
||||||
|
case RawMatrixer:
|
||||||
|
var d Dense
|
||||||
|
d.SetRawMatrix(m.RawMatrix())
|
||||||
|
return &d, trans
|
||||||
|
default:
|
||||||
|
return ut, trans
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// TODO(btracey): Consider adding CopyCol/CopyRow if the behavior seems useful.
|
// TODO(btracey): Consider adding CopyCol/CopyRow if the behavior seems useful.
|
||||||
// TODO(btracey): Add in fast paths to Row/Col for the other concrete types
|
// TODO(btracey): Add in fast paths to Row/Col for the other concrete types
|
||||||
// (TriDense, etc.) as well as relevant interfaces (RowColer, RawRowViewer, etc.)
|
// (TriDense, etc.) as well as relevant interfaces (RowColer, RawRowViewer, etc.)
|
||||||
@@ -803,44 +870,28 @@ func Sum(a Matrix) float64 {
|
|||||||
return sum
|
return sum
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// A Tracer can compute the trace of the matrix. Trace must panic if the
|
||||||
|
// matrix is not square.
|
||||||
|
type Tracer interface {
|
||||||
|
Trace() float64
|
||||||
|
}
|
||||||
|
|
||||||
// Trace returns the trace of the matrix. Trace will panic if the
|
// Trace returns the trace of the matrix. Trace will panic if the
|
||||||
// matrix is not square.
|
// matrix is not square.
|
||||||
func Trace(a Matrix) float64 {
|
func Trace(a Matrix) float64 {
|
||||||
|
m, _ := untransposeExtract(a)
|
||||||
|
if t, ok := m.(Tracer); ok {
|
||||||
|
return t.Trace()
|
||||||
|
}
|
||||||
r, c := a.Dims()
|
r, c := a.Dims()
|
||||||
if r != c {
|
if r != c {
|
||||||
panic(ErrSquare)
|
panic(ErrSquare)
|
||||||
}
|
}
|
||||||
|
var v float64
|
||||||
aU, _ := untranspose(a)
|
for i := 0; i < r; i++ {
|
||||||
switch m := aU.(type) {
|
v += a.At(i, i)
|
||||||
case RawMatrixer:
|
|
||||||
rm := m.RawMatrix()
|
|
||||||
var t float64
|
|
||||||
for i := 0; i < r; i++ {
|
|
||||||
t += rm.Data[i*rm.Stride+i]
|
|
||||||
}
|
|
||||||
return t
|
|
||||||
case RawTriangular:
|
|
||||||
rm := m.RawTriangular()
|
|
||||||
var t float64
|
|
||||||
for i := 0; i < r; i++ {
|
|
||||||
t += rm.Data[i*rm.Stride+i]
|
|
||||||
}
|
|
||||||
return t
|
|
||||||
case RawSymmetricer:
|
|
||||||
rm := m.RawSymmetric()
|
|
||||||
var t float64
|
|
||||||
for i := 0; i < r; i++ {
|
|
||||||
t += rm.Data[i*rm.Stride+i]
|
|
||||||
}
|
|
||||||
return t
|
|
||||||
default:
|
|
||||||
var t float64
|
|
||||||
for i := 0; i < r; i++ {
|
|
||||||
t += a.At(i, i)
|
|
||||||
}
|
|
||||||
return t
|
|
||||||
}
|
}
|
||||||
|
return v
|
||||||
}
|
}
|
||||||
|
|
||||||
func min(a, b int) int {
|
func min(a, b int) int {
|
||||||
|
@@ -145,6 +145,18 @@ func (s *SymBandDense) RawSymBand() blas64.SymmetricBand {
|
|||||||
return s.mat
|
return s.mat
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetRawSymBand sets the underlying blas64.SymmetricBand used by the receiver.
|
||||||
|
// Changes to elements in the receiver following the call will be reflected
|
||||||
|
// in the input.
|
||||||
|
//
|
||||||
|
// The supplied SymmetricBand must use blas.Upper storage format.
|
||||||
|
func (s *SymBandDense) SetRawSymBand(mat blas64.SymmetricBand) {
|
||||||
|
if mat.Uplo != blas.Upper {
|
||||||
|
panic("mat: blas64.SymmetricBand does not have blas.Upper storage")
|
||||||
|
}
|
||||||
|
s.mat = mat
|
||||||
|
}
|
||||||
|
|
||||||
// Zero sets all of the matrix elements to zero.
|
// Zero sets all of the matrix elements to zero.
|
||||||
func (s *SymBandDense) Zero() {
|
func (s *SymBandDense) Zero() {
|
||||||
for i := 0; i < s.mat.N; i++ {
|
for i := 0; i < s.mat.N; i++ {
|
||||||
|
@@ -113,13 +113,14 @@ func (s *SymDense) RawSymmetric() blas64.Symmetric {
|
|||||||
|
|
||||||
// SetRawSymmetric sets the underlying blas64.Symmetric used by the receiver.
|
// SetRawSymmetric sets the underlying blas64.Symmetric used by the receiver.
|
||||||
// Changes to elements in the receiver following the call will be reflected
|
// Changes to elements in the receiver following the call will be reflected
|
||||||
// in b. SetRawSymmetric will panic if b is not an upper-encoded symmetric
|
// in the input.
|
||||||
// matrix.
|
//
|
||||||
func (s *SymDense) SetRawSymmetric(b blas64.Symmetric) {
|
// The supplied Symmetric must use blas.Upper storage format.
|
||||||
if b.Uplo != blas.Upper {
|
func (s *SymDense) SetRawSymmetric(mat blas64.Symmetric) {
|
||||||
|
if mat.Uplo != blas.Upper {
|
||||||
panic(badSymTriangle)
|
panic(badSymTriangle)
|
||||||
}
|
}
|
||||||
s.mat = b
|
s.mat = mat
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reset zeros the dimensions of the matrix so that it can be reused as the
|
// Reset zeros the dimensions of the matrix so that it can be reused as the
|
||||||
@@ -514,6 +515,16 @@ func (s *SymDense) SliceSym(i, k int) Symmetric {
|
|||||||
return &v
|
return &v
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Trace returns the trace of the matrix.
|
||||||
|
func (s *SymDense) Trace() float64 {
|
||||||
|
// TODO(btracey): could use internal asm sum routine.
|
||||||
|
var v float64
|
||||||
|
for i := 0; i < s.mat.N; i++ {
|
||||||
|
v += s.mat.Data[i*s.mat.Stride+i]
|
||||||
|
}
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
// GrowSym returns the receiver expanded by n rows and n columns. If the
|
// GrowSym returns the receiver expanded by n rows and n columns. If the
|
||||||
// dimensions of the expanded matrix are outside the capacity of the receiver
|
// dimensions of the expanded matrix are outside the capacity of the receiver
|
||||||
// a new allocation is made, otherwise not. Note that the receiver itself is
|
// a new allocation is made, otherwise not. Note that the receiver itself is
|
||||||
|
@@ -217,6 +217,18 @@ func (t *TriDense) RawTriangular() blas64.Triangular {
|
|||||||
return t.mat
|
return t.mat
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetRawTriangular sets the underlying blas64.Triangular used by the receiver.
|
||||||
|
// Changes to elements in the receiver following the call will be reflected
|
||||||
|
// in the input.
|
||||||
|
//
|
||||||
|
// The supplied Triangular must not use blas.Unit storage format.
|
||||||
|
func (t *TriDense) SetRawTriangular(mat blas64.Triangular) {
|
||||||
|
if mat.Diag == blas.Unit {
|
||||||
|
panic("mat: cannot set TriDense with Unit storage format")
|
||||||
|
}
|
||||||
|
t.mat = mat
|
||||||
|
}
|
||||||
|
|
||||||
// Reset zeros the dimensions of the matrix so that it can be reused as the
|
// Reset zeros the dimensions of the matrix so that it can be reused as the
|
||||||
// receiver of a dimensionally restricted operation.
|
// receiver of a dimensionally restricted operation.
|
||||||
//
|
//
|
||||||
@@ -513,6 +525,16 @@ func (t *TriDense) ScaleTri(f float64, a Triangular) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Trace returns the trace of the matrix.
|
||||||
|
func (t *TriDense) Trace() float64 {
|
||||||
|
// TODO(btracey): could use internal asm sum routine.
|
||||||
|
var v float64
|
||||||
|
for i := 0; i < t.mat.N; i++ {
|
||||||
|
v += t.mat.Data[i*t.mat.Stride+i]
|
||||||
|
}
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
// copySymIntoTriangle copies a symmetric matrix into a TriDense
|
// copySymIntoTriangle copies a symmetric matrix into a TriDense
|
||||||
func copySymIntoTriangle(t *TriDense, s Symmetric) {
|
func copySymIntoTriangle(t *TriDense, s Symmetric) {
|
||||||
n, upper := t.Triangle()
|
n, upper := t.Triangle()
|
||||||
|
@@ -321,6 +321,18 @@ func (t *TriBandDense) RawTriBand() blas64.TriangularBand {
|
|||||||
return t.mat
|
return t.mat
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetRawTriBand sets the underlying blas64.TriangularBand used by the receiver.
|
||||||
|
// Changes to elements in the receiver following the call will be reflected
|
||||||
|
// in the input.
|
||||||
|
//
|
||||||
|
// The supplied TriangularBand must not use blas.Unit storage format.
|
||||||
|
func (t *TriBandDense) SetRawTriBand(mat blas64.TriangularBand) {
|
||||||
|
if mat.Diag == blas.Unit {
|
||||||
|
panic("mat: cannot set TriBand with Unit storage")
|
||||||
|
}
|
||||||
|
t.mat = mat
|
||||||
|
}
|
||||||
|
|
||||||
// DiagView returns the diagonal as a matrix backed by the original data.
|
// DiagView returns the diagonal as a matrix backed by the original data.
|
||||||
func (t *TriBandDense) DiagView() Diagonal {
|
func (t *TriBandDense) DiagView() Diagonal {
|
||||||
if t.mat.Diag == blas.Unit {
|
if t.mat.Diag == blas.Unit {
|
||||||
|
Reference in New Issue
Block a user