diff --git a/mat/eigen.go b/mat/eigen.go index 69f8eb57..3862a745 100644 --- a/mat/eigen.go +++ b/mat/eigen.go @@ -23,6 +23,43 @@ type EigenSym struct { vectors *Dense } +// Dims returns the dimensions of the matrix. +func (e *EigenSym) Dims() (r, c int) { + n := e.SymmetricDim() + return n, n +} + +// SymmetricDim implements the Symmetric interface. +func (e *EigenSym) SymmetricDim() int { + return len(e.values) +} + +// At returns the element at row i, column j. +// At will panic if the eigenvectors have not been computed. +func (e *EigenSym) At(i, j int) float64 { + if !e.vectorsComputed { + panic(noVectors) + } + n, _ := e.Dims() + if uint(i) >= uint(n) { + panic(ErrRowAccess) + } + if uint(j) >= uint(n) { + panic(ErrColAccess) + } + + var val float64 + for k := 0; k < n; k++ { + val += e.values[k] * e.vectors.at(i, k) * e.vectors.at(j, k) + } + return val +} + +// T returns the receiver, the transpose of a symmetric matrix. +func (e *EigenSym) T() Matrix { + return e +} + // Factorize computes the eigenvalue decomposition of the symmetric matrix a. // The Eigen decomposition is defined as // @@ -31,7 +68,8 @@ type EigenSym struct { // where D is a diagonal matrix containing the eigenvalues of the matrix, and // P is a matrix of the eigenvectors of A. Factorize computes the eigenvalues // in ascending order. If the vectors input argument is false, the eigenvectors -// are not computed. +// are not computed and the factorization cannot be used as a Matrix because +// At will panic. // // Factorize returns whether the decomposition succeeded. If the decomposition // failed, methods that require a successful factorization will panic. diff --git a/mat/eigen_test.go b/mat/eigen_test.go index 6723301e..b28b281d 100644 --- a/mat/eigen_test.go +++ b/mat/eigen_test.go @@ -131,10 +131,11 @@ func cmplxEqualTol(v1, v2 []complex128, tol float64) bool { return true } -func TestSymEigen(t *testing.T) { +func TestEigenSym(t *testing.T) { t.Parallel() + const tol = 1e-14 // Hand coded tests with results from lapack. - for _, test := range []struct { + for cas, test := range []struct { mat *SymDense values []float64 @@ -153,25 +154,26 @@ func TestSymEigen(t *testing.T) { var es EigenSym ok := es.Factorize(test.mat, true) if !ok { - t.Errorf("bad factorization") + t.Errorf("case %d: bad test", cas) + continue } - if !floats.EqualApprox(test.values, es.values, 1e-14) { - t.Errorf("Eigenvalue mismatch") + if !floats.EqualApprox(test.values, es.values, tol) { + t.Errorf("case %d: eigenvalue mismatch", cas) } - if !EqualApprox(test.vectors, es.vectors, 1e-14) { - t.Errorf("Eigenvector mismatch") + if !EqualApprox(test.vectors, es.vectors, tol) { + t.Errorf("case %d: eigenvector mismatch", cas) } var es2 EigenSym es2.Factorize(test.mat, false) - if !floats.EqualApprox(es2.values, es.values, 1e-14) { - t.Errorf("Eigenvalue mismatch when no vectors computed") + if !floats.EqualApprox(es2.values, es.values, tol) { + t.Errorf("case %d: eigenvalue mismatch when no vectors computed", cas) } } // Randomized tests rnd := rand.New(rand.NewSource(1)) - for _, n := range []int{3, 5, 10, 70} { + for _, n := range []int{1, 2, 3, 5, 10, 70} { for cas := 0; cas < 10; cas++ { a := make([]float64, n*n) for i := range a { @@ -181,12 +183,21 @@ func TestSymEigen(t *testing.T) { var es EigenSym ok := es.Factorize(s, true) if !ok { - t.Errorf("Bad test") + t.Errorf("n=%d,cas=%d: bad test", n, cas) + continue + } + + // Check that A and EigenSym are equal as Matrix. + if !EqualApprox(s, &es, tol*float64(n)) { + t.Errorf("n=%d,cas=%d: A and EigenSym are not equal as Matrix", n, cas) + } + if !EqualApprox(s.T(), es.T(), tol*float64(n)) { + t.Errorf("n=%d,cas=%d: Aᵀ and EigenSymᵀ are not equal as Matrix", n, cas) } // Check that the eigenvectors are orthonormal. if !isOrthonormal(es.vectors, 1e-8) { - t.Errorf("Eigenvectors not orthonormal") + t.Errorf("n=%d,cas=%d: eigenvectors not orthonormal", n, cas) } // Check that the eigenvalues are actually eigenvalues. @@ -199,13 +210,13 @@ func TestSymEigen(t *testing.T) { scal.ScaleVec(es.values[i], v) if !EqualApprox(&m, &scal, 1e-8) { - t.Errorf("Eigenvalue does not match") + t.Errorf("n=%d,cas=%d: eigenvalue %d does not match", n, cas, i) } } // Check that the eigenvalues are in ascending order. if !sort.Float64sAreSorted(es.values) { - t.Errorf("Eigenvalues not ascending") + t.Errorf("n=%d,cas=%d: eigenvalues not ascending", n, cas) } } }