mat: add TriDense.SliceTri

This commit is contained in:
Vladimir Chalupecky
2020-03-07 21:23:43 +01:00
committed by Vladimír Chalupecký
parent 5f268d9394
commit 4363550baf
2 changed files with 99 additions and 0 deletions

View File

@@ -600,6 +600,26 @@ func (t *TriDense) ScaleTri(f float64, a Triangular) {
}
}
// SliceTri returns a new Triangular that shares backing data with the receiver.
// The returned matrix starts at {i,i} of the receiver and extends k-i rows and
// columns. The final row and column in the resulting matrix is k-1.
// SliceTri panics with ErrIndexOutOfRange if the slice is outside the capacity
// of the receiver.
func (t *TriDense) SliceTri(i, k int) Triangular {
return t.sliceTri(i, k)
}
func (t *TriDense) sliceTri(i, k int) *TriDense {
if i < 0 || t.cap < i || k < i || t.cap < k {
panic(ErrIndexOutOfRange)
}
v := *t
v.mat.Data = t.mat.Data[i*t.mat.Stride+i : (k-1)*t.mat.Stride+k]
v.mat.N = k - i
v.cap = t.cap - i
return &v
}
// Trace returns the trace of the matrix.
func (t *TriDense) Trace() float64 {
// TODO(btracey): could use internal asm sum routine.

View File

@@ -514,6 +514,85 @@ func TestCopySymIntoTriangle(t *testing.T) {
}
}
func TestTriSliceTri(t *testing.T) {
rnd := rand.New(rand.NewSource(1))
for cas, test := range []struct {
n, start1, span1, start2, span2 int
}{
{10, 0, 10, 0, 10},
{10, 0, 8, 0, 8},
{10, 2, 8, 0, 6},
{10, 2, 7, 4, 2},
{10, 2, 6, 0, 5},
{10, 2, 3, 1, 7},
} {
n := test.n
for _, kind := range []TriKind{Upper, Lower} {
tri := NewTriDense(n, kind, nil)
if kind == Upper {
for i := 0; i < n; i++ {
for j := i; j < n; j++ {
tri.SetTri(i, j, rnd.Float64())
}
}
} else {
for i := 0; i < n; i++ {
for j := 0; j <= i; j++ {
tri.SetTri(i, j, rnd.Float64())
}
}
}
start1 := test.start1
span1 := test.span1
v1 := tri.SliceTri(start1, start1+span1).(*TriDense)
if kind == Upper {
for i := 0; i < span1; i++ {
for j := i; j < span1; j++ {
if v1.At(i, j) != tri.At(start1+i, start1+j) {
t.Errorf("Case %d,upper: view mismatch at %v,%v", cas, i, j)
}
}
}
} else {
for i := 0; i < span1; i++ {
for j := 0; j <= i; j++ {
if v1.At(i, j) != tri.At(start1+i, start1+j) {
t.Errorf("Case %d,lower: view mismatch at %v,%v", cas, i, j)
}
}
}
}
start2 := test.start2
span2 := test.span2
v2 := v1.SliceTri(start2, start2+span2).(*TriDense)
if kind == Upper {
for i := 0; i < span2; i++ {
for j := i; j < span2; j++ {
if v2.At(i, j) != tri.At(start1+start2+i, start1+start2+j) {
t.Errorf("Case %d,upper: second view mismatch at %v,%v", cas, i, j)
}
}
}
} else {
for i := 0; i < span1; i++ {
for j := 0; j <= i; j++ {
if v1.At(i, j) != tri.At(start1+i, start1+j) {
t.Errorf("Case %d,lower: second view mismatch at %v,%v", cas, i, j)
}
}
}
}
v2.SetTri(span2-1, span2-1, -123.45)
if tri.At(start1+start2+span2-1, start1+start2+span2-1) != -123.45 {
t.Errorf("Case %d: write to view not reflected in original", cas)
}
}
}
}
var triSumForBench float64
func BenchmarkTriSum(b *testing.B) {