mirror of
https://github.com/gonum/gonum.git
synced 2025-10-06 15:47:01 +08:00
mat: Add Scale method for TriDense and Cholesky (#267)
* mat: Add Scale method for TriDense and Cholesky
This commit is contained in:
@@ -308,6 +308,31 @@ func (c *Cholesky) InverseTo(s *SymDense) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Scale multiplies the original matrix A by a positive constant using
|
||||||
|
// its Cholesky decomposition, storing the result in-place into the receiver.
|
||||||
|
// That is, if the original Cholesky factorization is
|
||||||
|
// U^T * U = A
|
||||||
|
// the updated factorization is
|
||||||
|
// U'^T * U' = f A = A'
|
||||||
|
// Scale panics if the constant is non-positive, or if the receiver is non-zero
|
||||||
|
// and is of a different Size from the input.
|
||||||
|
func (c *Cholesky) Scale(f float64, orig *Cholesky) {
|
||||||
|
if !orig.valid() {
|
||||||
|
panic(badCholesky)
|
||||||
|
}
|
||||||
|
if f <= 0 {
|
||||||
|
panic("cholesky: scaling by a non-positive constant")
|
||||||
|
}
|
||||||
|
n := orig.Size()
|
||||||
|
if c.isZero() {
|
||||||
|
c.chol = NewTriDense(n, Upper, nil)
|
||||||
|
} else if c.chol.mat.N != n {
|
||||||
|
panic(ErrShape)
|
||||||
|
}
|
||||||
|
c.chol.ScaleTri(math.Sqrt(f), orig.chol)
|
||||||
|
c.cond = orig.cond // Scaling by a positive constant does not change the condition number.
|
||||||
|
}
|
||||||
|
|
||||||
// SymRankOne performs a rank-1 update of the original matrix A and refactorizes
|
// SymRankOne performs a rank-1 update of the original matrix A and refactorizes
|
||||||
// its Cholesky factorization, storing the result into the receiver. That is, if
|
// its Cholesky factorization, storing the result into the receiver. That is, if
|
||||||
// in the original Cholesky factorization
|
// in the original Cholesky factorization
|
||||||
|
@@ -10,6 +10,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"gonum.org/v1/gonum/blas/testblas"
|
"gonum.org/v1/gonum/blas/testblas"
|
||||||
|
"gonum.org/v1/gonum/floats"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestCholesky(t *testing.T) {
|
func TestCholesky(t *testing.T) {
|
||||||
@@ -421,6 +422,65 @@ func TestCholeskySymRankOne(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCholeskyScale(t *testing.T) {
|
||||||
|
for cas, test := range []struct {
|
||||||
|
a *SymDense
|
||||||
|
f float64
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
a: NewSymDense(3, []float64{
|
||||||
|
4, 1, 1,
|
||||||
|
0, 2, 3,
|
||||||
|
0, 0, 6,
|
||||||
|
}),
|
||||||
|
f: 0.5,
|
||||||
|
},
|
||||||
|
} {
|
||||||
|
var chol Cholesky
|
||||||
|
ok := chol.Factorize(test.a)
|
||||||
|
if !ok {
|
||||||
|
t.Errorf("Case %v: bad test, Cholesky factorization failed", cas)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compare the update to a new Cholesky to an update in-place.
|
||||||
|
var cholUpdate Cholesky
|
||||||
|
cholUpdate.Scale(test.f, &chol)
|
||||||
|
chol.Scale(test.f, &chol)
|
||||||
|
if !equalChol(&chol, &cholUpdate) {
|
||||||
|
t.Errorf("Case %d: cholesky mismatch new receiver", cas)
|
||||||
|
}
|
||||||
|
var sym SymDense
|
||||||
|
chol.ToSym(&sym)
|
||||||
|
var comp SymDense
|
||||||
|
comp.ScaleSym(test.f, test.a)
|
||||||
|
if !EqualApprox(&comp, &sym, 1e-14) {
|
||||||
|
t.Errorf("Case %d: cholesky reconstruction doesn't match scaled matrix", cas)
|
||||||
|
}
|
||||||
|
|
||||||
|
var cholTest Cholesky
|
||||||
|
cholTest.Factorize(&comp)
|
||||||
|
if !equalApproxChol(&cholTest, &chol, 1e-12, 1e-12) {
|
||||||
|
t.Errorf("Case %d: cholesky mismatch with scaled matrix. %v, %v", cas, cholTest.cond, chol.cond)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// equalApproxChol checks that the two Cholesky decompositions are equal.
|
||||||
|
func equalChol(a, b *Cholesky) bool {
|
||||||
|
return Equal(a.chol, b.chol) && a.cond == b.cond
|
||||||
|
}
|
||||||
|
|
||||||
|
// equalApproxChol checks that the two Cholesky decompositions are approximately
|
||||||
|
// the same with the given tolerance on equality for the Triangular component and
|
||||||
|
// condition.
|
||||||
|
func equalApproxChol(a, b *Cholesky, matTol, condTol float64) bool {
|
||||||
|
if !EqualApprox(a.chol, b.chol, matTol) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return floats.EqualWithinAbsOrRel(a.cond, b.cond, condTol, condTol)
|
||||||
|
}
|
||||||
|
|
||||||
func BenchmarkCholeskySmall(b *testing.B) {
|
func BenchmarkCholeskySmall(b *testing.B) {
|
||||||
benchmarkCholesky(b, 2)
|
benchmarkCholesky(b, 2)
|
||||||
}
|
}
|
||||||
|
@@ -131,6 +131,32 @@ func legalTypeSym(a Matrix) bool {
|
|||||||
return ok
|
return ok
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// legalTypeTri returns whether a is a Triangular.
|
||||||
|
func legalTypeTri(a Matrix) bool {
|
||||||
|
_, ok := a.(Triangular)
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// legalTypeTriLower returns whether a is a Triangular with kind == Lower.
|
||||||
|
func legalTypeTriLower(a Matrix) bool {
|
||||||
|
t, ok := a.(Triangular)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
_, kind := t.Triangle()
|
||||||
|
return kind == Lower
|
||||||
|
}
|
||||||
|
|
||||||
|
// legalTypeTriUpper returns whether a is a Triangular with kind == Upper.
|
||||||
|
func legalTypeTriUpper(a Matrix) bool {
|
||||||
|
t, ok := a.(Triangular)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
_, kind := t.Triangle()
|
||||||
|
return kind == Upper
|
||||||
|
}
|
||||||
|
|
||||||
// legalTypesSym returns whether both input arguments are Symmetric.
|
// legalTypesSym returns whether both input arguments are Symmetric.
|
||||||
func legalTypesSym(a, b Matrix) bool {
|
func legalTypesSym(a, b Matrix) bool {
|
||||||
if _, ok := a.(Symmetric); !ok {
|
if _, ok := a.(Symmetric); !ok {
|
||||||
|
@@ -416,6 +416,51 @@ func (t *TriDense) MulTri(a, b Triangular) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ScaleTri multiplies the elements of a by f, placing the result in the receiver.
|
||||||
|
// If the receiver is non-zero, the size and kind of the receiver must match
|
||||||
|
// the input, or ScaleTri will panic.
|
||||||
|
func (t *TriDense) ScaleTri(f float64, a Triangular) {
|
||||||
|
n, kind := a.Triangle()
|
||||||
|
t.reuseAs(n, kind)
|
||||||
|
|
||||||
|
// TODO(btracey): Improve the set of fast-paths.
|
||||||
|
switch a := a.(type) {
|
||||||
|
case RawTriangular:
|
||||||
|
amat := a.RawTriangular()
|
||||||
|
if kind == Upper {
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
ts := t.mat.Data[i*t.mat.Stride+i : i*t.mat.Stride+n]
|
||||||
|
as := amat.Data[i*amat.Stride+i : i*amat.Stride+n]
|
||||||
|
for i, v := range as {
|
||||||
|
ts[i] = v * f
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
ts := t.mat.Data[i*t.mat.Stride : i*t.mat.Stride+i+1]
|
||||||
|
as := amat.Data[i*amat.Stride : i*amat.Stride+i+1]
|
||||||
|
for i, v := range as {
|
||||||
|
ts[i] = v * f
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
isUpper := kind == Upper
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
if isUpper {
|
||||||
|
for j := i; j < n; j++ {
|
||||||
|
t.set(i, j, f*a.At(i, j))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for j := 0; j <= i; j++ {
|
||||||
|
t.set(i, j, f*a.At(i, j))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// copySymIntoTriangle copies a symmetric matrix into a TriDense
|
// copySymIntoTriangle copies a symmetric matrix into a TriDense
|
||||||
func copySymIntoTriangle(t *TriDense, s Symmetric) {
|
func copySymIntoTriangle(t *TriDense, s Symmetric) {
|
||||||
n, upper := t.Triangle()
|
n, upper := t.Triangle()
|
||||||
|
@@ -319,6 +319,23 @@ func TestTriMul(t *testing.T) {
|
|||||||
testTwoInput(t, "TriMul", receiver, method, denseComparison, legalTypesUpper, legalSizeTriMul, 1e-14)
|
testTwoInput(t, "TriMul", receiver, method, denseComparison, legalTypesUpper, legalSizeTriMul, 1e-14)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestScaleTri(t *testing.T) {
|
||||||
|
for _, f := range []float64{0.5, 1, 3} {
|
||||||
|
method := func(receiver, a Matrix) {
|
||||||
|
type ScaleTrier interface {
|
||||||
|
ScaleTri(f float64, a Triangular)
|
||||||
|
}
|
||||||
|
rd := receiver.(ScaleTrier)
|
||||||
|
rd.ScaleTri(f, a.(Triangular))
|
||||||
|
}
|
||||||
|
denseComparison := func(receiver, a *Dense) {
|
||||||
|
receiver.Scale(f, a)
|
||||||
|
}
|
||||||
|
testOneInput(t, "ScaleTriUpper", NewTriDense(3, Upper, nil), method, denseComparison, legalTypeTriUpper, isSquare, 1e-14)
|
||||||
|
testOneInput(t, "ScaleTriLower", NewTriDense(3, Lower, nil), method, denseComparison, legalTypeTriLower, isSquare, 1e-14)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestCopySymIntoTriangle(t *testing.T) {
|
func TestCopySymIntoTriangle(t *testing.T) {
|
||||||
nan := math.NaN()
|
nan := math.NaN()
|
||||||
for tc, test := range []struct {
|
for tc, test := range []struct {
|
||||||
|
Reference in New Issue
Block a user