mat: add TriBandDense.SolveVec and SolveVecTo

This commit is contained in:
Vladimir Chalupecky
2020-10-10 15:26:37 +02:00
committed by Vladimír Chalupecký
parent 40b831e267
commit d520c6cf9e
2 changed files with 222 additions and 0 deletions

View File

@@ -5,8 +5,11 @@
package mat
import (
"math"
"gonum.org/v1/gonum/blas"
"gonum.org/v1/gonum/blas/blas64"
"gonum.org/v1/gonum/lapack/lapack64"
)
var (
@@ -458,6 +461,64 @@ func (t *TriBandDense) Trace() float64 {
return tr
}
// SolveTo solves a triangular system T * X = B or Tᵀ * X = B where T is an
// n×n triangular band matrix represented by the receiver and B is a given
// n×nrhs matrix. If T is non-singular, the result will be stored into dst and
// nil will be returned. If T is singular, the contents of dst will be undefined
// and a Condition error will be returned.
func (t *TriBandDense) SolveTo(dst *Dense, trans bool, b Matrix) error {
n, nrhs := b.Dims()
if n != t.mat.N {
panic(ErrShape)
}
if b, ok := b.(RawMatrixer); ok && dst != b {
dst.checkOverlap(b.RawMatrix())
}
dst.reuseAsNonZeroed(n, nrhs)
if dst != b {
dst.Copy(b)
}
var ok bool
if trans {
ok = lapack64.Tbtrs(blas.Trans, t.mat, dst.mat)
} else {
ok = lapack64.Tbtrs(blas.NoTrans, t.mat, dst.mat)
}
if !ok {
return Condition(math.Inf(1))
}
return nil
}
// SolveVecTo solves a triangular system T * x = b or Tᵀ * x = b where T is an
// n×n triangular band matrix represented by the receiver and b is a given
// n-vector. If T is non-singular, the result will be stored into dst and nil
// will be returned. If T is singular, the contents of dst will be undefined and
// a Condition error will be returned.
func (t *TriBandDense) SolveVecTo(dst *VecDense, trans bool, b Vector) error {
n, nrhs := b.Dims()
if n != t.mat.N || nrhs != 1 {
panic(ErrShape)
}
if b, ok := b.(RawVectorer); ok && dst != b {
dst.checkOverlap(b.RawVector())
}
dst.reuseAsNonZeroed(n)
if dst != b {
dst.CopyVec(b)
}
var ok bool
if trans {
ok = lapack64.Tbtrs(blas.Trans, t.mat, dst.asGeneral())
} else {
ok = lapack64.Tbtrs(blas.NoTrans, t.mat, dst.asGeneral())
}
if !ok {
return Condition(math.Inf(1))
}
return nil
}
func copySymBandIntoTriBand(dst *TriBandDense, s SymBanded) {
n, k, upper := dst.TriBand()
ns, ks := s.SymBand()

View File

@@ -5,6 +5,7 @@
package mat
import (
"fmt"
"reflect"
"testing"
@@ -414,3 +415,163 @@ func TestTriBandDiagView(t *testing.T) {
testDiagView(t, cas, test)
}
}
func TestTriBandDenseSolveTo(t *testing.T) {
t.Parallel()
const tol = 1e-15
for tc, test := range []struct {
a *TriBandDense
b *Dense
}{
{
a: NewTriBandDense(5, 2, Upper, []float64{
-0.34, -0.49, -0.51,
-0.25, -0.5, 1.03,
-1.1, 0.3, -0.82,
1.69, 0.69, -2.22,
-0.62, 1.22, -0.85,
}),
b: NewDense(5, 2, []float64{
0.44, 1.34,
0.07, -1.45,
-0.32, -0.88,
-0.09, -0.15,
-1.17, -0.19,
}),
},
{
a: NewTriBandDense(5, 2, Lower, []float64{
0, 0, -0.34,
0, -0.49, -0.25,
-0.51, -0.5, -1.1,
1.03, 0.3, 1.69,
-0.82, 0.69, -0.62,
}),
b: NewDense(5, 2, []float64{
0.44, 1.34,
0.07, -1.45,
-0.32, -0.88,
-0.09, -0.15,
-1.17, -0.19,
}),
},
} {
a := test.a
for _, trans := range []bool{false, true} {
for _, dstSameAsB := range []bool{false, true} {
name := fmt.Sprintf("Case %d,trans=%v,dstSameAsB=%v", tc, trans, dstSameAsB)
n, nrhs := test.b.Dims()
var dst Dense
var err error
if dstSameAsB {
dst = *NewDense(n, nrhs, nil)
dst.Copy(test.b)
err = a.SolveTo(&dst, trans, &dst)
} else {
tmp := NewDense(n, nrhs, nil)
tmp.Copy(test.b)
err = a.SolveTo(&dst, trans, asBasicMatrix(tmp))
}
if err != nil {
t.Fatalf("%v: unexpected error from SolveTo", name)
}
var resid Dense
if trans {
resid.Mul(a.T(), &dst)
} else {
resid.Mul(a, &dst)
}
resid.Sub(&resid, test.b)
diff := Norm(&resid, 1)
if diff > tol {
t.Errorf("%v: unexpected result; diff=%v,want<=%v", name, diff, tol)
}
}
}
}
}
func TestTriBandDenseSolveVecTo(t *testing.T) {
t.Parallel()
const tol = 1e-15
for tc, test := range []struct {
a *TriBandDense
b *VecDense
}{
{
a: NewTriBandDense(5, 2, Upper, []float64{
-0.34, -0.49, -0.51,
-0.25, -0.5, 1.03,
-1.1, 0.3, -0.82,
1.69, 0.69, -2.22,
-0.62, 1.22, -0.85,
}),
b: NewVecDense(5, []float64{
0.44,
0.07,
-0.32,
-0.09,
-1.17,
}),
},
{
a: NewTriBandDense(5, 2, Lower, []float64{
0, 0, -0.34,
0, -0.49, -0.25,
-0.51, -0.5, -1.1,
1.03, 0.3, 1.69,
-0.82, 0.69, -0.62,
}),
b: NewVecDense(5, []float64{
0.44,
0.07,
-0.32,
-0.09,
-1.17,
}),
},
} {
a := test.a
for _, trans := range []bool{false, true} {
for _, dstSameAsB := range []bool{false, true} {
name := fmt.Sprintf("Case %d,trans=%v,dstSameAsB=%v", tc, trans, dstSameAsB)
n, _ := test.b.Dims()
var dst VecDense
var err error
if dstSameAsB {
dst = *NewVecDense(n, nil)
dst.CopyVec(test.b)
err = a.SolveVecTo(&dst, trans, &dst)
} else {
tmp := NewVecDense(n, nil)
tmp.CopyVec(test.b)
err = a.SolveVecTo(&dst, trans, asBasicVector(tmp))
}
if err != nil {
t.Fatalf("%v: unexpected error from SolveVecTo", name)
}
var resid VecDense
if trans {
resid.MulVec(a.T(), &dst)
} else {
resid.MulVec(a, &dst)
}
resid.SubVec(&resid, test.b)
diff := Norm(&resid, 1)
if diff > tol {
t.Errorf("%v: unexpected result; diff=%v,want<=%v", name, diff, tol)
}
}
}
}
}