mirror of
https://github.com/gonum/gonum.git
synced 2025-10-05 23:26:52 +08:00
mat: add NonZeroDoer interfaces and implementations
Also clean up some documentation and missing type checks related to tests for NonZeroDoers.
This commit is contained in:
@@ -521,3 +521,94 @@ func TestTrace(t *testing.T) {
|
||||
}
|
||||
testOneInputFunc(t, "Trace", f, denseComparison, sameAnswerFloat, isAnyType, isSquare)
|
||||
}
|
||||
|
||||
func TestDoer(t *testing.T) {
|
||||
type MatrixDoer interface {
|
||||
Matrix
|
||||
NonZeroDoer
|
||||
RowNonZeroDoer
|
||||
ColNonZeroDoer
|
||||
}
|
||||
ones := func(n int) []float64 {
|
||||
data := make([]float64, n)
|
||||
for i := range data {
|
||||
data[i] = 1
|
||||
}
|
||||
return data
|
||||
}
|
||||
for i, m := range []MatrixDoer{
|
||||
NewTriDense(3, Lower, ones(3*3)),
|
||||
NewTriDense(3, Upper, ones(3*3)),
|
||||
NewBandDense(6, 6, 1, 1, ones(3*6)),
|
||||
NewBandDense(6, 10, 1, 1, ones(3*6)),
|
||||
NewBandDense(10, 6, 1, 1, ones(7*3)),
|
||||
NewSymBandDense(3, 0, ones(3)),
|
||||
NewSymBandDense(3, 1, ones(3*(1+1))),
|
||||
NewSymBandDense(6, 1, ones(6*(1+1))),
|
||||
NewSymBandDense(6, 2, ones(6*(2+1))),
|
||||
} {
|
||||
r, c := m.Dims()
|
||||
|
||||
want := Sum(m)
|
||||
|
||||
// got and fn sum the accessed elements in
|
||||
// the Doer that is being operated on.
|
||||
// fn also tests that the accessed elements
|
||||
// are within the writable areas of the
|
||||
// matrix to check that only valid elements
|
||||
// are operated on.
|
||||
var got float64
|
||||
fn := func(i, j int, v float64) {
|
||||
got += v
|
||||
switch m := m.(type) {
|
||||
case MutableTriangular:
|
||||
m.SetTri(i, j, v)
|
||||
case MutableBanded:
|
||||
m.SetBand(i, j, v)
|
||||
case MutableSymBanded:
|
||||
m.SetSymBand(i, j, v)
|
||||
default:
|
||||
panic("bad test: need mutable type")
|
||||
}
|
||||
}
|
||||
|
||||
panicked, message := panics(func() { m.DoNonZero(fn) })
|
||||
if panicked {
|
||||
t.Errorf("unexpected panic for Doer test %d: %q", i, message)
|
||||
continue
|
||||
}
|
||||
if got != want {
|
||||
t.Errorf("unexpected Doer sum: got:%f want:%f", got, want)
|
||||
}
|
||||
|
||||
// Reset got for testing with DoRowNonZero.
|
||||
got = 0
|
||||
panicked, message = panics(func() {
|
||||
for i := 0; i < r; i++ {
|
||||
m.DoRowNonZero(i, fn)
|
||||
}
|
||||
})
|
||||
if panicked {
|
||||
t.Errorf("unexpected panic for RowDoer test %d: %q", i, message)
|
||||
continue
|
||||
}
|
||||
if got != want {
|
||||
t.Errorf("unexpected RowDoer sum: got:%f want:%f", got, want)
|
||||
}
|
||||
|
||||
// Reset got for testing with DoColNonZero.
|
||||
got = 0
|
||||
panicked, message = panics(func() {
|
||||
for j := 0; j < c; j++ {
|
||||
m.DoColNonZero(j, fn)
|
||||
}
|
||||
})
|
||||
if panicked {
|
||||
t.Errorf("unexpected panic for ColDoer test %d: %q", i, message)
|
||||
continue
|
||||
}
|
||||
if got != want {
|
||||
t.Errorf("unexpected ColDoer sum: got:%f want:%f", got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user