mirror of
https://github.com/gonum/gonum.git
synced 2025-12-24 13:47:56 +08:00
mat: add TriBandDense.SolveVec and SolveVecTo
This commit is contained in:
committed by
Vladimír Chalupecký
parent
40b831e267
commit
d520c6cf9e
@@ -5,8 +5,11 @@
|
||||
package mat
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"gonum.org/v1/gonum/blas"
|
||||
"gonum.org/v1/gonum/blas/blas64"
|
||||
"gonum.org/v1/gonum/lapack/lapack64"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -458,6 +461,64 @@ func (t *TriBandDense) Trace() float64 {
|
||||
return tr
|
||||
}
|
||||
|
||||
// SolveTo solves a triangular system T * X = B or Tᵀ * X = B where T is an
|
||||
// n×n triangular band matrix represented by the receiver and B is a given
|
||||
// n×nrhs matrix. If T is non-singular, the result will be stored into dst and
|
||||
// nil will be returned. If T is singular, the contents of dst will be undefined
|
||||
// and a Condition error will be returned.
|
||||
func (t *TriBandDense) SolveTo(dst *Dense, trans bool, b Matrix) error {
|
||||
n, nrhs := b.Dims()
|
||||
if n != t.mat.N {
|
||||
panic(ErrShape)
|
||||
}
|
||||
if b, ok := b.(RawMatrixer); ok && dst != b {
|
||||
dst.checkOverlap(b.RawMatrix())
|
||||
}
|
||||
dst.reuseAsNonZeroed(n, nrhs)
|
||||
if dst != b {
|
||||
dst.Copy(b)
|
||||
}
|
||||
var ok bool
|
||||
if trans {
|
||||
ok = lapack64.Tbtrs(blas.Trans, t.mat, dst.mat)
|
||||
} else {
|
||||
ok = lapack64.Tbtrs(blas.NoTrans, t.mat, dst.mat)
|
||||
}
|
||||
if !ok {
|
||||
return Condition(math.Inf(1))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SolveVecTo solves a triangular system T * x = b or Tᵀ * x = b where T is an
|
||||
// n×n triangular band matrix represented by the receiver and b is a given
|
||||
// n-vector. If T is non-singular, the result will be stored into dst and nil
|
||||
// will be returned. If T is singular, the contents of dst will be undefined and
|
||||
// a Condition error will be returned.
|
||||
func (t *TriBandDense) SolveVecTo(dst *VecDense, trans bool, b Vector) error {
|
||||
n, nrhs := b.Dims()
|
||||
if n != t.mat.N || nrhs != 1 {
|
||||
panic(ErrShape)
|
||||
}
|
||||
if b, ok := b.(RawVectorer); ok && dst != b {
|
||||
dst.checkOverlap(b.RawVector())
|
||||
}
|
||||
dst.reuseAsNonZeroed(n)
|
||||
if dst != b {
|
||||
dst.CopyVec(b)
|
||||
}
|
||||
var ok bool
|
||||
if trans {
|
||||
ok = lapack64.Tbtrs(blas.Trans, t.mat, dst.asGeneral())
|
||||
} else {
|
||||
ok = lapack64.Tbtrs(blas.NoTrans, t.mat, dst.asGeneral())
|
||||
}
|
||||
if !ok {
|
||||
return Condition(math.Inf(1))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func copySymBandIntoTriBand(dst *TriBandDense, s SymBanded) {
|
||||
n, k, upper := dst.TriBand()
|
||||
ns, ks := s.SymBand()
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
package mat
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
@@ -414,3 +415,163 @@ func TestTriBandDiagView(t *testing.T) {
|
||||
testDiagView(t, cas, test)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTriBandDenseSolveTo(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const tol = 1e-15
|
||||
|
||||
for tc, test := range []struct {
|
||||
a *TriBandDense
|
||||
b *Dense
|
||||
}{
|
||||
{
|
||||
a: NewTriBandDense(5, 2, Upper, []float64{
|
||||
-0.34, -0.49, -0.51,
|
||||
-0.25, -0.5, 1.03,
|
||||
-1.1, 0.3, -0.82,
|
||||
1.69, 0.69, -2.22,
|
||||
-0.62, 1.22, -0.85,
|
||||
}),
|
||||
b: NewDense(5, 2, []float64{
|
||||
0.44, 1.34,
|
||||
0.07, -1.45,
|
||||
-0.32, -0.88,
|
||||
-0.09, -0.15,
|
||||
-1.17, -0.19,
|
||||
}),
|
||||
},
|
||||
{
|
||||
a: NewTriBandDense(5, 2, Lower, []float64{
|
||||
0, 0, -0.34,
|
||||
0, -0.49, -0.25,
|
||||
-0.51, -0.5, -1.1,
|
||||
1.03, 0.3, 1.69,
|
||||
-0.82, 0.69, -0.62,
|
||||
}),
|
||||
b: NewDense(5, 2, []float64{
|
||||
0.44, 1.34,
|
||||
0.07, -1.45,
|
||||
-0.32, -0.88,
|
||||
-0.09, -0.15,
|
||||
-1.17, -0.19,
|
||||
}),
|
||||
},
|
||||
} {
|
||||
a := test.a
|
||||
for _, trans := range []bool{false, true} {
|
||||
for _, dstSameAsB := range []bool{false, true} {
|
||||
name := fmt.Sprintf("Case %d,trans=%v,dstSameAsB=%v", tc, trans, dstSameAsB)
|
||||
|
||||
n, nrhs := test.b.Dims()
|
||||
var dst Dense
|
||||
var err error
|
||||
if dstSameAsB {
|
||||
dst = *NewDense(n, nrhs, nil)
|
||||
dst.Copy(test.b)
|
||||
err = a.SolveTo(&dst, trans, &dst)
|
||||
} else {
|
||||
tmp := NewDense(n, nrhs, nil)
|
||||
tmp.Copy(test.b)
|
||||
err = a.SolveTo(&dst, trans, asBasicMatrix(tmp))
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("%v: unexpected error from SolveTo", name)
|
||||
}
|
||||
|
||||
var resid Dense
|
||||
if trans {
|
||||
resid.Mul(a.T(), &dst)
|
||||
} else {
|
||||
resid.Mul(a, &dst)
|
||||
}
|
||||
resid.Sub(&resid, test.b)
|
||||
diff := Norm(&resid, 1)
|
||||
if diff > tol {
|
||||
t.Errorf("%v: unexpected result; diff=%v,want<=%v", name, diff, tol)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTriBandDenseSolveVecTo(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const tol = 1e-15
|
||||
|
||||
for tc, test := range []struct {
|
||||
a *TriBandDense
|
||||
b *VecDense
|
||||
}{
|
||||
{
|
||||
a: NewTriBandDense(5, 2, Upper, []float64{
|
||||
-0.34, -0.49, -0.51,
|
||||
-0.25, -0.5, 1.03,
|
||||
-1.1, 0.3, -0.82,
|
||||
1.69, 0.69, -2.22,
|
||||
-0.62, 1.22, -0.85,
|
||||
}),
|
||||
b: NewVecDense(5, []float64{
|
||||
0.44,
|
||||
0.07,
|
||||
-0.32,
|
||||
-0.09,
|
||||
-1.17,
|
||||
}),
|
||||
},
|
||||
{
|
||||
a: NewTriBandDense(5, 2, Lower, []float64{
|
||||
0, 0, -0.34,
|
||||
0, -0.49, -0.25,
|
||||
-0.51, -0.5, -1.1,
|
||||
1.03, 0.3, 1.69,
|
||||
-0.82, 0.69, -0.62,
|
||||
}),
|
||||
b: NewVecDense(5, []float64{
|
||||
0.44,
|
||||
0.07,
|
||||
-0.32,
|
||||
-0.09,
|
||||
-1.17,
|
||||
}),
|
||||
},
|
||||
} {
|
||||
a := test.a
|
||||
for _, trans := range []bool{false, true} {
|
||||
for _, dstSameAsB := range []bool{false, true} {
|
||||
name := fmt.Sprintf("Case %d,trans=%v,dstSameAsB=%v", tc, trans, dstSameAsB)
|
||||
|
||||
n, _ := test.b.Dims()
|
||||
var dst VecDense
|
||||
var err error
|
||||
if dstSameAsB {
|
||||
dst = *NewVecDense(n, nil)
|
||||
dst.CopyVec(test.b)
|
||||
err = a.SolveVecTo(&dst, trans, &dst)
|
||||
} else {
|
||||
tmp := NewVecDense(n, nil)
|
||||
tmp.CopyVec(test.b)
|
||||
err = a.SolveVecTo(&dst, trans, asBasicVector(tmp))
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("%v: unexpected error from SolveVecTo", name)
|
||||
}
|
||||
|
||||
var resid VecDense
|
||||
if trans {
|
||||
resid.MulVec(a.T(), &dst)
|
||||
} else {
|
||||
resid.MulVec(a, &dst)
|
||||
}
|
||||
resid.SubVec(&resid, test.b)
|
||||
diff := Norm(&resid, 1)
|
||||
if diff > tol {
|
||||
t.Errorf("%v: unexpected result; diff=%v,want<=%v", name, diff, tol)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user