mirror of
https://github.com/gonum/gonum.git
synced 2025-10-05 15:16:59 +08:00
mat: add BandCholesky type
This commit is contained in:

committed by
Vladimír Chalupecký

parent
aea9ac7fa3
commit
971fc50f31
220
mat/cholesky.go
220
mat/cholesky.go
@@ -20,6 +20,11 @@ const (
|
||||
var (
|
||||
_ Matrix = (*Cholesky)(nil)
|
||||
_ Symmetric = (*Cholesky)(nil)
|
||||
|
||||
_ Matrix = (*BandCholesky)(nil)
|
||||
_ Symmetric = (*BandCholesky)(nil)
|
||||
_ Banded = (*BandCholesky)(nil)
|
||||
_ SymBanded = (*BandCholesky)(nil)
|
||||
)
|
||||
|
||||
// Cholesky is a symmetric positive definite matrix represented by its
|
||||
@@ -100,7 +105,7 @@ func (c *Cholesky) At(i, j int) float64 {
|
||||
return val
|
||||
}
|
||||
|
||||
// T returns the the receiver, the transpose of a symmetric matrix.
|
||||
// T returns the receiver, the transpose of a symmetric matrix.
|
||||
func (c *Cholesky) T() Matrix {
|
||||
return c
|
||||
}
|
||||
@@ -264,7 +269,7 @@ func (a *Cholesky) SolveCholTo(dst *Dense, b *Cholesky) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SolveVecTo finds the vector X that solves A * x = b where A is represented
|
||||
// SolveVecTo finds the vector x that solves A * x = b where A is represented
|
||||
// by the Cholesky decomposition. The result is stored in-place into
|
||||
// dst.
|
||||
func (c *Cholesky) SolveVecTo(dst *VecDense, b Vector) error {
|
||||
@@ -697,3 +702,214 @@ func (c *Cholesky) SymRankOne(orig *Cholesky, alpha float64, x Vector) (ok bool)
|
||||
func (c *Cholesky) valid() bool {
|
||||
return c.chol != nil && !c.chol.IsEmpty()
|
||||
}
|
||||
|
||||
// BandCholesky is a symmetric positive-definite band matrix represented by its
|
||||
// Cholesky decomposition.
|
||||
//
|
||||
// Note that this matrix representation is useful for certain operations, in
|
||||
// particular finding solutions to linear equations. It is very inefficient at
|
||||
// other operations, in particular At is slow.
|
||||
//
|
||||
// BandCholesky methods may only be called on a value that has been successfully
|
||||
// initialized by a call to Factorize that has returned true. Calls to methods
|
||||
// of an unsuccessful Cholesky factorization will panic.
|
||||
type BandCholesky struct {
|
||||
// The chol pointer must never be retained as a pointer outside the Cholesky
|
||||
// struct, either by returning chol outside the struct or by setting it to
|
||||
// a pointer coming from outside. The same prohibition applies to the data
|
||||
// slice within chol.
|
||||
chol *TriBandDense
|
||||
cond float64
|
||||
}
|
||||
|
||||
// Factorize calculates the Cholesky decomposition of the matrix A and returns
|
||||
// whether the matrix is positive definite. If Factorize returns false, the
|
||||
// factorization must not be used.
|
||||
func (ch *BandCholesky) Factorize(a SymBanded) (ok bool) {
|
||||
n, k := a.SymBand()
|
||||
if ch.chol == nil {
|
||||
ch.chol = NewTriBandDense(n, k, Upper, nil)
|
||||
} else {
|
||||
ch.chol.Reset()
|
||||
ch.chol.ReuseAsTriBand(n, k, Upper)
|
||||
}
|
||||
copySymBandIntoTriBand(ch.chol, a)
|
||||
cSym := blas64.SymmetricBand{
|
||||
Uplo: blas.Upper,
|
||||
N: n,
|
||||
K: k,
|
||||
Data: ch.chol.RawTriBand().Data,
|
||||
Stride: ch.chol.RawTriBand().Stride,
|
||||
}
|
||||
_, ok = lapack64.Pbtrf(cSym)
|
||||
if !ok {
|
||||
ch.Reset()
|
||||
return false
|
||||
}
|
||||
work := getFloats(3*n, false)
|
||||
iwork := getInts(n, false)
|
||||
aNorm := lapack64.Lansb(CondNorm, cSym, work)
|
||||
ch.cond = 1 / lapack64.Pbcon(cSym, aNorm, work, iwork)
|
||||
putInts(iwork)
|
||||
putFloats(work)
|
||||
return true
|
||||
}
|
||||
|
||||
// SolveTo finds the matrix X that solves A * X = B where A is represented by
|
||||
// the Cholesky decomposition. The result is stored in-place into dst.
|
||||
func (ch *BandCholesky) SolveTo(dst *Dense, b Matrix) error {
|
||||
if !ch.valid() {
|
||||
panic(badCholesky)
|
||||
}
|
||||
br, bc := b.Dims()
|
||||
if br != ch.chol.mat.N {
|
||||
panic(ErrShape)
|
||||
}
|
||||
dst.reuseAsNonZeroed(br, bc)
|
||||
if b != dst {
|
||||
dst.Copy(b)
|
||||
}
|
||||
lapack64.Pbtrs(ch.chol.mat, dst.mat)
|
||||
if ch.cond > ConditionTolerance {
|
||||
return Condition(ch.cond)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SolveVecTo finds the vector x that solves A * x = b where A is represented by
|
||||
// the Cholesky decomposition. The result is stored in-place into dst.
|
||||
func (ch *BandCholesky) SolveVecTo(dst *VecDense, b Vector) error {
|
||||
if !ch.valid() {
|
||||
panic(badCholesky)
|
||||
}
|
||||
n := ch.chol.mat.N
|
||||
if br, bc := b.Dims(); br != n || bc != 1 {
|
||||
panic(ErrShape)
|
||||
}
|
||||
if b, ok := b.(RawVectorer); ok && dst != b {
|
||||
dst.checkOverlap(b.RawVector())
|
||||
}
|
||||
dst.reuseAsNonZeroed(n)
|
||||
if dst != b {
|
||||
dst.CopyVec(b)
|
||||
}
|
||||
lapack64.Pbtrs(ch.chol.mat, dst.asGeneral())
|
||||
if ch.cond > ConditionTolerance {
|
||||
return Condition(ch.cond)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Cond returns the condition number of the factorized matrix.
|
||||
func (ch *BandCholesky) Cond() float64 {
|
||||
if !ch.valid() {
|
||||
panic(badCholesky)
|
||||
}
|
||||
return ch.cond
|
||||
}
|
||||
|
||||
// Reset resets the factorization so that it can be reused as the receiver of
|
||||
// a dimensionally restricted operation.
|
||||
func (ch *BandCholesky) Reset() {
|
||||
if ch.chol != nil {
|
||||
ch.chol.Reset()
|
||||
}
|
||||
ch.cond = math.Inf(1)
|
||||
}
|
||||
|
||||
// Dims returns the dimensions of the matrix.
|
||||
func (ch *BandCholesky) Dims() (r, c int) {
|
||||
if !ch.valid() {
|
||||
panic(badCholesky)
|
||||
}
|
||||
r, c = ch.chol.Dims()
|
||||
return r, c
|
||||
}
|
||||
|
||||
// At returns the element at row i, column j.
|
||||
func (ch *BandCholesky) At(i, j int) float64 {
|
||||
if !ch.valid() {
|
||||
panic(badCholesky)
|
||||
}
|
||||
n, k, _ := ch.chol.TriBand()
|
||||
if uint(i) >= uint(n) {
|
||||
panic(ErrRowAccess)
|
||||
}
|
||||
if uint(j) >= uint(n) {
|
||||
panic(ErrColAccess)
|
||||
}
|
||||
|
||||
if i > j {
|
||||
i, j = j, i
|
||||
}
|
||||
if j-i > k {
|
||||
return 0
|
||||
}
|
||||
var aij float64
|
||||
for k := max(0, j-k); k <= i; k++ {
|
||||
aij += ch.chol.at(k, i) * ch.chol.at(k, j)
|
||||
}
|
||||
return aij
|
||||
}
|
||||
|
||||
// T returns the receiver, the transpose of a symmetric matrix.
|
||||
func (ch *BandCholesky) T() Matrix {
|
||||
return ch
|
||||
}
|
||||
|
||||
// TBand returns the receiver, the transpose of a symmetric band matrix.
|
||||
func (ch *BandCholesky) TBand() Banded {
|
||||
return ch
|
||||
}
|
||||
|
||||
// Symmetric implements the Symmetric interface and returns the number of rows
|
||||
// in the matrix (this is also the number of columns).
|
||||
func (ch *BandCholesky) Symmetric() int {
|
||||
n, _ := ch.chol.Triangle()
|
||||
return n
|
||||
}
|
||||
|
||||
// Bandwidth returns the lower and upper bandwidth values for the matrix.
|
||||
// The total bandwidth of the matrix is kl+ku+1.
|
||||
func (ch *BandCholesky) Bandwidth() (kl, ku int) {
|
||||
_, k, _ := ch.chol.TriBand()
|
||||
return k, k
|
||||
}
|
||||
|
||||
// SymBand returns the number of rows/columns in the matrix, and the size of the
|
||||
// bandwidth. The total bandwidth of the matrix is 2*k+1.
|
||||
func (ch *BandCholesky) SymBand() (n, k int) {
|
||||
n, k, _ = ch.chol.TriBand()
|
||||
return n, k
|
||||
}
|
||||
|
||||
// IsEmpty returns whether the receiver is empty. Empty matrices can be the
|
||||
// receiver for dimensionally restricted operations. The receiver can be emptied
|
||||
// using Reset.
|
||||
func (ch *BandCholesky) IsEmpty() bool {
|
||||
return ch == nil || ch.chol.IsEmpty()
|
||||
}
|
||||
|
||||
// Det returns the determinant of the matrix that has been factorized.
|
||||
func (ch *BandCholesky) Det() float64 {
|
||||
if !ch.valid() {
|
||||
panic(badCholesky)
|
||||
}
|
||||
return math.Exp(ch.LogDet())
|
||||
}
|
||||
|
||||
// LogDet returns the log of the determinant of the matrix that has been factorized.
|
||||
func (ch *BandCholesky) LogDet() float64 {
|
||||
if !ch.valid() {
|
||||
panic(badCholesky)
|
||||
}
|
||||
var det float64
|
||||
for i := 0; i < ch.chol.mat.N; i++ {
|
||||
det += 2 * math.Log(ch.chol.mat.Data[i*ch.chol.mat.Stride])
|
||||
}
|
||||
return det
|
||||
}
|
||||
|
||||
func (ch *BandCholesky) valid() bool {
|
||||
return ch.chol != nil && !ch.chol.IsEmpty()
|
||||
}
|
||||
|
@@ -5,6 +5,7 @@
|
||||
package mat
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"strconv"
|
||||
"testing"
|
||||
@@ -707,3 +708,219 @@ func BenchmarkCholeskyInverseTo(b *testing.B) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBandCholeskySolveTo(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const (
|
||||
nrhs = 4
|
||||
tol = 1e-14
|
||||
)
|
||||
rnd := rand.New(rand.NewSource(1))
|
||||
for _, n := range []int{1, 2, 3, 5, 10} {
|
||||
for _, k := range []int{0, 1, n / 2, n - 1} {
|
||||
k := min(k, n-1)
|
||||
|
||||
a := NewSymBandDense(n, k, nil)
|
||||
for i := 0; i < n; i++ {
|
||||
a.SetSymBand(i, i, rnd.Float64()+float64(n))
|
||||
for j := i + 1; j < min(i+k+1, n); j++ {
|
||||
a.SetSymBand(i, j, rnd.Float64())
|
||||
}
|
||||
}
|
||||
|
||||
want := NewDense(n, nrhs, nil)
|
||||
for i := 0; i < n; i++ {
|
||||
for j := 0; j < nrhs; j++ {
|
||||
want.Set(i, j, rnd.NormFloat64())
|
||||
}
|
||||
}
|
||||
var b Dense
|
||||
b.Mul(a, want)
|
||||
|
||||
for _, typ := range []SymBanded{a, (*basicSymBanded)(a)} {
|
||||
name := fmt.Sprintf("Case n=%d,k=%d,type=%T,nrhs=%d", n, k, typ, nrhs)
|
||||
|
||||
var chol BandCholesky
|
||||
ok := chol.Factorize(typ)
|
||||
if !ok {
|
||||
t.Fatalf("%v: Factorize failed", name)
|
||||
}
|
||||
|
||||
var got Dense
|
||||
err := chol.SolveTo(&got, &b)
|
||||
if err != nil {
|
||||
t.Errorf("%v: unexpected error from SolveTo: %v", name, err)
|
||||
continue
|
||||
}
|
||||
|
||||
var resid Dense
|
||||
resid.Sub(want, &got)
|
||||
diff := Norm(&resid, math.Inf(1))
|
||||
if diff > tol {
|
||||
t.Errorf("%v: unexpected solution; diff=%v", name, diff)
|
||||
}
|
||||
|
||||
got.Copy(&b)
|
||||
err = chol.SolveTo(&got, &got)
|
||||
if err != nil {
|
||||
t.Errorf("%v: unexpected error from SolveTo when dst==b: %v", name, err)
|
||||
continue
|
||||
}
|
||||
|
||||
resid.Sub(want, &got)
|
||||
diff = Norm(&resid, math.Inf(1))
|
||||
if diff > tol {
|
||||
t.Errorf("%v: unexpected solution when dst==b; diff=%v", name, diff)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBandCholeskySolveVecTo(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const tol = 1e-14
|
||||
rnd := rand.New(rand.NewSource(1))
|
||||
for _, n := range []int{1, 2, 3, 5, 10} {
|
||||
for _, k := range []int{0, 1, n / 2, n - 1} {
|
||||
k := min(k, n-1)
|
||||
|
||||
a := NewSymBandDense(n, k, nil)
|
||||
for i := 0; i < n; i++ {
|
||||
a.SetSymBand(i, i, rnd.Float64()+float64(n))
|
||||
for j := i + 1; j < min(i+k+1, n); j++ {
|
||||
a.SetSymBand(i, j, rnd.Float64())
|
||||
}
|
||||
}
|
||||
|
||||
want := NewVecDense(n, nil)
|
||||
for i := 0; i < n; i++ {
|
||||
want.SetVec(i, rnd.NormFloat64())
|
||||
}
|
||||
var b VecDense
|
||||
b.MulVec(a, want)
|
||||
|
||||
for _, typ := range []SymBanded{a, (*basicSymBanded)(a)} {
|
||||
name := fmt.Sprintf("Case n=%d,k=%d,type=%T", n, k, typ)
|
||||
|
||||
var chol BandCholesky
|
||||
ok := chol.Factorize(typ)
|
||||
if !ok {
|
||||
t.Fatalf("%v: Factorize failed", name)
|
||||
}
|
||||
|
||||
var got VecDense
|
||||
err := chol.SolveVecTo(&got, &b)
|
||||
if err != nil {
|
||||
t.Errorf("%v: unexpected error from SolveVecTo: %v", name, err)
|
||||
continue
|
||||
}
|
||||
|
||||
var resid VecDense
|
||||
resid.SubVec(want, &got)
|
||||
diff := Norm(&resid, math.Inf(1))
|
||||
if diff > tol {
|
||||
t.Errorf("%v: unexpected solution; diff=%v", name, diff)
|
||||
}
|
||||
|
||||
got.CopyVec(&b)
|
||||
err = chol.SolveVecTo(&got, &got)
|
||||
if err != nil {
|
||||
t.Errorf("%v: unexpected error from SolveVecTo when dst==b: %v", name, err)
|
||||
continue
|
||||
}
|
||||
|
||||
resid.SubVec(want, &got)
|
||||
diff = Norm(&resid, math.Inf(1))
|
||||
if diff > tol {
|
||||
t.Errorf("%v: unexpected solution when dst==b; diff=%v", name, diff)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBandCholeskyAt(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const tol = 1e-14
|
||||
rnd := rand.New(rand.NewSource(1))
|
||||
for _, n := range []int{1, 2, 3, 5, 10} {
|
||||
for _, k := range []int{0, 1, n / 2, n - 1} {
|
||||
k := min(k, n-1)
|
||||
name := fmt.Sprintf("Case n=%d,k=%d", n, k)
|
||||
|
||||
a := NewSymBandDense(n, k, nil)
|
||||
for i := 0; i < n; i++ {
|
||||
a.SetSymBand(i, i, rnd.Float64()+float64(n))
|
||||
for j := i + 1; j < min(i+k+1, n); j++ {
|
||||
a.SetSymBand(i, j, rnd.Float64())
|
||||
}
|
||||
}
|
||||
|
||||
var chol BandCholesky
|
||||
ok := chol.Factorize(a)
|
||||
if !ok {
|
||||
t.Fatalf("%v: Factorize failed", name)
|
||||
}
|
||||
|
||||
resid := NewDense(n, n, nil)
|
||||
for i := 0; i < n; i++ {
|
||||
for j := 0; j < n; j++ {
|
||||
resid.Set(i, j, math.Abs(a.At(i, j)-chol.At(i, j)))
|
||||
}
|
||||
}
|
||||
diff := Norm(resid, math.Inf(1))
|
||||
if diff > tol {
|
||||
t.Errorf("%v: unexpected result; diff=%v, want<=%v", name, diff, tol)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBandCholeskyDet(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const tol = 1e-14
|
||||
rnd := rand.New(rand.NewSource(1))
|
||||
for _, n := range []int{1, 2, 3, 5, 10} {
|
||||
for _, k := range []int{0, 1, n / 2, n - 1} {
|
||||
k := min(k, n-1)
|
||||
name := fmt.Sprintf("Case n=%d,k=%d", n, k)
|
||||
|
||||
a := NewSymBandDense(n, k, nil)
|
||||
aSym := NewSymDense(n, nil)
|
||||
for i := 0; i < n; i++ {
|
||||
aii := rnd.Float64() + float64(n)
|
||||
a.SetSymBand(i, i, aii)
|
||||
aSym.SetSym(i, i, aii)
|
||||
for j := i + 1; j < min(i+k+1, n); j++ {
|
||||
aij := rnd.Float64()
|
||||
a.SetSymBand(i, j, aij)
|
||||
aSym.SetSym(i, j, aij)
|
||||
}
|
||||
}
|
||||
|
||||
var chol BandCholesky
|
||||
ok := chol.Factorize(a)
|
||||
if !ok {
|
||||
t.Fatalf("%v: Factorize failed", name)
|
||||
}
|
||||
|
||||
var cholDense Cholesky
|
||||
ok = cholDense.Factorize(aSym)
|
||||
if !ok {
|
||||
t.Fatalf("%v: dense Factorize failed", name)
|
||||
}
|
||||
|
||||
want := cholDense.Det()
|
||||
got := chol.Det()
|
||||
diff := math.Abs(got - want)
|
||||
if diff > tol {
|
||||
t.Errorf("%v: unexpected result; got=%v, want=%v (diff=%v)", name, got, want, diff)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user