mirror of
https://github.com/gonum/gonum.git
synced 2025-10-06 23:52:47 +08:00
mat: add NonZeroDoer interfaces and implementations
Also clean up some documentation and missing type checks related to tests for NonZeroDoers.
This commit is contained in:
53
mat/band.go
53
mat/band.go
@@ -13,6 +13,10 @@ var (
|
|||||||
_ Matrix = bandDense
|
_ Matrix = bandDense
|
||||||
_ Banded = bandDense
|
_ Banded = bandDense
|
||||||
_ RawBander = bandDense
|
_ RawBander = bandDense
|
||||||
|
|
||||||
|
_ NonZeroDoer = bandDense
|
||||||
|
_ RowNonZeroDoer = bandDense
|
||||||
|
_ ColNonZeroDoer = bandDense
|
||||||
)
|
)
|
||||||
|
|
||||||
// BandDense represents a band matrix in dense storage format.
|
// BandDense represents a band matrix in dense storage format.
|
||||||
@@ -39,6 +43,12 @@ type RawBander interface {
|
|||||||
RawBand() blas64.Band
|
RawBand() blas64.Band
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// A MutableBanded can set elements of a band matrix.
|
||||||
|
type MutableBanded interface {
|
||||||
|
Banded
|
||||||
|
SetBand(i, j int, v float64)
|
||||||
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
_ Matrix = TransposeBand{}
|
_ Matrix = TransposeBand{}
|
||||||
_ Banded = TransposeBand{}
|
_ Banded = TransposeBand{}
|
||||||
@@ -173,3 +183,46 @@ func (b *BandDense) TBand() Banded {
|
|||||||
func (b *BandDense) RawBand() blas64.Band {
|
func (b *BandDense) RawBand() blas64.Band {
|
||||||
return b.mat
|
return b.mat
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DoNonZero calls the function fn for each of the non-zero elements of b. The function fn
|
||||||
|
// takes a row/column index and the element value of b at (i, j).
|
||||||
|
func (b *BandDense) DoNonZero(fn func(i, j int, v float64)) {
|
||||||
|
for i := 0; i < min(b.mat.Rows, b.mat.Cols+b.mat.KL); i++ {
|
||||||
|
for j := max(0, i-b.mat.KL); j < min(b.mat.Cols, i+b.mat.KU+1); j++ {
|
||||||
|
v := b.at(i, j)
|
||||||
|
if v != 0 {
|
||||||
|
fn(i, j, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DoRowNonZero calls the function fn for each of the non-zero elements of row i of b. The function fn
|
||||||
|
// takes a row/column index and the element value of b at (i, j).
|
||||||
|
func (b *BandDense) DoRowNonZero(i int, fn func(i, j int, v float64)) {
|
||||||
|
if i < 0 || b.mat.Rows <= i {
|
||||||
|
panic(ErrRowAccess)
|
||||||
|
}
|
||||||
|
for j := max(0, i-b.mat.KL); j < min(b.mat.Cols, i+b.mat.KU+1); j++ {
|
||||||
|
v := b.at(i, j)
|
||||||
|
if v != 0 {
|
||||||
|
fn(i, j, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DoColNonZero calls the function fn for each of the non-zero elements of column j of b. The function fn
|
||||||
|
// takes a row/column index and the element value of b at (i, j).
|
||||||
|
func (b *BandDense) DoColNonZero(j int, fn func(i, j int, v float64)) {
|
||||||
|
if j < 0 || b.mat.Cols <= j {
|
||||||
|
panic(ErrColAccess)
|
||||||
|
}
|
||||||
|
for i := 0; i < min(b.mat.Rows, b.mat.Cols+b.mat.KL); i++ {
|
||||||
|
if i-b.mat.KL <= j && j < i+b.mat.KU+1 {
|
||||||
|
v := b.at(i, j)
|
||||||
|
if v != 0 {
|
||||||
|
fn(i, j, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@@ -182,6 +182,24 @@ type RawVectorer interface {
|
|||||||
RawVector() blas64.Vector
|
RawVector() blas64.Vector
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// A NonZeroDoer can call a function for each non-zero element of the receiver.
|
||||||
|
// The parameters of the function are the element indices and its value.
|
||||||
|
type NonZeroDoer interface {
|
||||||
|
DoNonZero(func(i, j int, v float64))
|
||||||
|
}
|
||||||
|
|
||||||
|
// A RowNonZeroDoer can call a function for each non-zero element of a row of the receiver.
|
||||||
|
// The parameters of the function are the element indices and its value.
|
||||||
|
type RowNonZeroDoer interface {
|
||||||
|
DoRowNonZero(i int, fn func(i, j int, v float64))
|
||||||
|
}
|
||||||
|
|
||||||
|
// A ColNonZeroDoer can call a function for each non-zero element of a column of the receiver.
|
||||||
|
// The parameters of the function are the element indices and its value.
|
||||||
|
type ColNonZeroDoer interface {
|
||||||
|
DoColNonZero(j int, fn func(i, j int, v float64))
|
||||||
|
}
|
||||||
|
|
||||||
// TODO(btracey): Consider adding CopyCol/CopyRow if the behavior seems useful.
|
// TODO(btracey): Consider adding CopyCol/CopyRow if the behavior seems useful.
|
||||||
// TODO(btracey): Add in fast paths to Row/Col for the other concrete types
|
// TODO(btracey): Add in fast paths to Row/Col for the other concrete types
|
||||||
// (TriDense, etc.) as well as relevant interfaces (RowColer, RawRowViewer, etc.)
|
// (TriDense, etc.) as well as relevant interfaces (RowColer, RawRowViewer, etc.)
|
||||||
|
@@ -521,3 +521,94 @@ func TestTrace(t *testing.T) {
|
|||||||
}
|
}
|
||||||
testOneInputFunc(t, "Trace", f, denseComparison, sameAnswerFloat, isAnyType, isSquare)
|
testOneInputFunc(t, "Trace", f, denseComparison, sameAnswerFloat, isAnyType, isSquare)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDoer(t *testing.T) {
|
||||||
|
type MatrixDoer interface {
|
||||||
|
Matrix
|
||||||
|
NonZeroDoer
|
||||||
|
RowNonZeroDoer
|
||||||
|
ColNonZeroDoer
|
||||||
|
}
|
||||||
|
ones := func(n int) []float64 {
|
||||||
|
data := make([]float64, n)
|
||||||
|
for i := range data {
|
||||||
|
data[i] = 1
|
||||||
|
}
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
for i, m := range []MatrixDoer{
|
||||||
|
NewTriDense(3, Lower, ones(3*3)),
|
||||||
|
NewTriDense(3, Upper, ones(3*3)),
|
||||||
|
NewBandDense(6, 6, 1, 1, ones(3*6)),
|
||||||
|
NewBandDense(6, 10, 1, 1, ones(3*6)),
|
||||||
|
NewBandDense(10, 6, 1, 1, ones(7*3)),
|
||||||
|
NewSymBandDense(3, 0, ones(3)),
|
||||||
|
NewSymBandDense(3, 1, ones(3*(1+1))),
|
||||||
|
NewSymBandDense(6, 1, ones(6*(1+1))),
|
||||||
|
NewSymBandDense(6, 2, ones(6*(2+1))),
|
||||||
|
} {
|
||||||
|
r, c := m.Dims()
|
||||||
|
|
||||||
|
want := Sum(m)
|
||||||
|
|
||||||
|
// got and fn sum the accessed elements in
|
||||||
|
// the Doer that is being operated on.
|
||||||
|
// fn also tests that the accessed elements
|
||||||
|
// are within the writable areas of the
|
||||||
|
// matrix to check that only valid elements
|
||||||
|
// are operated on.
|
||||||
|
var got float64
|
||||||
|
fn := func(i, j int, v float64) {
|
||||||
|
got += v
|
||||||
|
switch m := m.(type) {
|
||||||
|
case MutableTriangular:
|
||||||
|
m.SetTri(i, j, v)
|
||||||
|
case MutableBanded:
|
||||||
|
m.SetBand(i, j, v)
|
||||||
|
case MutableSymBanded:
|
||||||
|
m.SetSymBand(i, j, v)
|
||||||
|
default:
|
||||||
|
panic("bad test: need mutable type")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
panicked, message := panics(func() { m.DoNonZero(fn) })
|
||||||
|
if panicked {
|
||||||
|
t.Errorf("unexpected panic for Doer test %d: %q", i, message)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if got != want {
|
||||||
|
t.Errorf("unexpected Doer sum: got:%f want:%f", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset got for testing with DoRowNonZero.
|
||||||
|
got = 0
|
||||||
|
panicked, message = panics(func() {
|
||||||
|
for i := 0; i < r; i++ {
|
||||||
|
m.DoRowNonZero(i, fn)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
if panicked {
|
||||||
|
t.Errorf("unexpected panic for RowDoer test %d: %q", i, message)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if got != want {
|
||||||
|
t.Errorf("unexpected RowDoer sum: got:%f want:%f", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset got for testing with DoColNonZero.
|
||||||
|
got = 0
|
||||||
|
panicked, message = panics(func() {
|
||||||
|
for j := 0; j < c; j++ {
|
||||||
|
m.DoColNonZero(j, fn)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
if panicked {
|
||||||
|
t.Errorf("unexpected panic for ColDoer test %d: %q", i, message)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if got != want {
|
||||||
|
t.Errorf("unexpected ColDoer sum: got:%f want:%f", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@@ -16,6 +16,10 @@ var (
|
|||||||
_ Banded = symBandDense
|
_ Banded = symBandDense
|
||||||
_ RawSymBander = symBandDense
|
_ RawSymBander = symBandDense
|
||||||
_ MutableSymBanded = symBandDense
|
_ MutableSymBanded = symBandDense
|
||||||
|
|
||||||
|
_ NonZeroDoer = symBandDense
|
||||||
|
_ RowNonZeroDoer = symBandDense
|
||||||
|
_ ColNonZeroDoer = symBandDense
|
||||||
)
|
)
|
||||||
|
|
||||||
// SymBandDense represents a symmetric band matrix in dense storage format.
|
// SymBandDense represents a symmetric band matrix in dense storage format.
|
||||||
@@ -41,7 +45,7 @@ type RawSymBander interface {
|
|||||||
// NewSymBandDense creates a new SymBand matrix with n rows and columns. If data == nil,
|
// NewSymBandDense creates a new SymBand matrix with n rows and columns. If data == nil,
|
||||||
// a new slice is allocated for the backing slice. If len(data) == n*(k+1),
|
// a new slice is allocated for the backing slice. If len(data) == n*(k+1),
|
||||||
// data is used as the backing slice, and changes to the elements of the returned
|
// data is used as the backing slice, and changes to the elements of the returned
|
||||||
// BandDense will be reflected in data. If neither of these is true, NewSymBandDense
|
// SymBandDense will be reflected in data. If neither of these is true, NewSymBandDense
|
||||||
// will panic. k must be at least zero and less than n, otherwise NewBandDense will panic.
|
// will panic. k must be at least zero and less than n, otherwise NewBandDense will panic.
|
||||||
//
|
//
|
||||||
// The data must be arranged in row-major order constructed by removing the zeros
|
// The data must be arranged in row-major order constructed by removing the zeros
|
||||||
@@ -126,3 +130,46 @@ func (s *SymBandDense) TBand() Banded {
|
|||||||
func (s *SymBandDense) RawSymBand() blas64.SymmetricBand {
|
func (s *SymBandDense) RawSymBand() blas64.SymmetricBand {
|
||||||
return s.mat
|
return s.mat
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DoNonZero calls the function fn for each of the non-zero elements of s. The function fn
|
||||||
|
// takes a row/column index and the element value of s at (i, j).
|
||||||
|
func (s *SymBandDense) DoNonZero(fn func(i, j int, v float64)) {
|
||||||
|
for i := 0; i < s.mat.N; i++ {
|
||||||
|
for j := max(0, i-s.mat.K); j < min(s.mat.N, i+s.mat.K+1); j++ {
|
||||||
|
v := s.at(i, j)
|
||||||
|
if v != 0 {
|
||||||
|
fn(i, j, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DoRowNonZero calls the function fn for each of the non-zero elements of row i of s. The function fn
|
||||||
|
// takes a row/column index and the element value of s at (i, j).
|
||||||
|
func (s *SymBandDense) DoRowNonZero(i int, fn func(i, j int, v float64)) {
|
||||||
|
if i < 0 || s.mat.N <= i {
|
||||||
|
panic(ErrRowAccess)
|
||||||
|
}
|
||||||
|
for j := max(0, i-s.mat.K); j < min(s.mat.N, i+s.mat.K+1); j++ {
|
||||||
|
v := s.at(i, j)
|
||||||
|
if v != 0 {
|
||||||
|
fn(i, j, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DoColNonZero calls the function fn for each of the non-zero elements of column j of s. The function fn
|
||||||
|
// takes a row/column index and the element value of s at (i, j).
|
||||||
|
func (s *SymBandDense) DoColNonZero(j int, fn func(i, j int, v float64)) {
|
||||||
|
if j < 0 || s.mat.N <= j {
|
||||||
|
panic(ErrColAccess)
|
||||||
|
}
|
||||||
|
for i := 0; i < s.mat.N; i++ {
|
||||||
|
if i-s.mat.K <= j && j < i+s.mat.K+1 {
|
||||||
|
v := s.at(i, j)
|
||||||
|
if v != 0 {
|
||||||
|
fn(i, j, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@@ -45,6 +45,7 @@ type RawSymmetricer interface {
|
|||||||
RawSymmetric() blas64.Symmetric
|
RawSymmetric() blas64.Symmetric
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// A MutableSymmetric can set elements of a symmetric matrix.
|
||||||
type MutableSymmetric interface {
|
type MutableSymmetric interface {
|
||||||
Symmetric
|
Symmetric
|
||||||
SetSym(i, j int, v float64)
|
SetSym(i, j int, v float64)
|
||||||
|
@@ -14,9 +14,14 @@ import (
|
|||||||
|
|
||||||
var (
|
var (
|
||||||
triDense *TriDense
|
triDense *TriDense
|
||||||
_ Matrix = triDense
|
_ Matrix = triDense
|
||||||
_ Triangular = triDense
|
_ Triangular = triDense
|
||||||
_ RawTriangular = triDense
|
_ RawTriangular = triDense
|
||||||
|
_ MutableTriangular = triDense
|
||||||
|
|
||||||
|
_ NonZeroDoer = triDense
|
||||||
|
_ RowNonZeroDoer = triDense
|
||||||
|
_ ColNonZeroDoer = triDense
|
||||||
)
|
)
|
||||||
|
|
||||||
const badTriCap = "mat: bad capacity for TriDense"
|
const badTriCap = "mat: bad capacity for TriDense"
|
||||||
@@ -28,6 +33,7 @@ type TriDense struct {
|
|||||||
cap int
|
cap int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Triangular represents a triangular matrix. Triangular matrices are always square.
|
||||||
type Triangular interface {
|
type Triangular interface {
|
||||||
Matrix
|
Matrix
|
||||||
// Triangular returns the number of rows/columns in the matrix and its
|
// Triangular returns the number of rows/columns in the matrix and its
|
||||||
@@ -39,10 +45,17 @@ type Triangular interface {
|
|||||||
TTri() Triangular
|
TTri() Triangular
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// A RawTriangular can return a view of itself as a BLAS Triangular matrix.
|
||||||
type RawTriangular interface {
|
type RawTriangular interface {
|
||||||
RawTriangular() blas64.Triangular
|
RawTriangular() blas64.Triangular
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// A MutableTriangular can set elements of a triangular matrix.
|
||||||
|
type MutableTriangular interface {
|
||||||
|
Triangular
|
||||||
|
SetTri(i, j int, v float64)
|
||||||
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
_ Matrix = TransposeTri{}
|
_ Matrix = TransposeTri{}
|
||||||
_ Triangular = TransposeTri{}
|
_ Triangular = TransposeTri{}
|
||||||
@@ -455,3 +468,73 @@ func copySymIntoTriangle(t *TriDense, s Symmetric) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DoNonZero calls the function fn for each of the non-zero elements of t. The function fn
|
||||||
|
// takes a row/column index and the element value of t at (i, j).
|
||||||
|
func (t *TriDense) DoNonZero(fn func(i, j int, v float64)) {
|
||||||
|
if t.isUpper() {
|
||||||
|
for i := 0; i < t.mat.N; i++ {
|
||||||
|
for j := i; j < t.mat.N; j++ {
|
||||||
|
v := t.at(i, j)
|
||||||
|
if v != 0 {
|
||||||
|
fn(i, j, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for i := 0; i < t.mat.N; i++ {
|
||||||
|
for j := 0; j <= i; j++ {
|
||||||
|
v := t.at(i, j)
|
||||||
|
if v != 0 {
|
||||||
|
fn(i, j, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DoRowNonZero calls the function fn for each of the non-zero elements of row i of t. The function fn
|
||||||
|
// takes a row/column index and the element value of t at (i, j).
|
||||||
|
func (t *TriDense) DoRowNonZero(i int, fn func(i, j int, v float64)) {
|
||||||
|
if i < 0 || t.mat.N <= i {
|
||||||
|
panic(ErrRowAccess)
|
||||||
|
}
|
||||||
|
if t.isUpper() {
|
||||||
|
for j := i; j < t.mat.N; j++ {
|
||||||
|
v := t.at(i, j)
|
||||||
|
if v != 0 {
|
||||||
|
fn(i, j, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for j := 0; j <= i; j++ {
|
||||||
|
v := t.at(i, j)
|
||||||
|
if v != 0 {
|
||||||
|
fn(i, j, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DoColNonZero calls the function fn for each of the non-zero elements of column j of t. The function fn
|
||||||
|
// takes a row/column index and the element value of t at (i, j).
|
||||||
|
func (t *TriDense) DoColNonZero(j int, fn func(i, j int, v float64)) {
|
||||||
|
if j < 0 || t.mat.N <= j {
|
||||||
|
panic(ErrColAccess)
|
||||||
|
}
|
||||||
|
if t.isUpper() {
|
||||||
|
for i := 0; i <= j; i++ {
|
||||||
|
v := t.at(i, j)
|
||||||
|
if v != 0 {
|
||||||
|
fn(i, j, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for i := j; i < t.mat.N; i++ {
|
||||||
|
v := t.at(i, j)
|
||||||
|
if v != 0 {
|
||||||
|
fn(i, j, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user