// Copyright ©2015 The Gonum Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package mat import ( "testing" "golang.org/x/exp/rand" ) func TestSolve(t *testing.T) { t.Parallel() rnd := rand.New(rand.NewSource(1)) // Hand-coded cases. for _, test := range []struct { a [][]float64 b [][]float64 ans [][]float64 shouldErr bool }{ { a: [][]float64{{6}}, b: [][]float64{{3}}, ans: [][]float64{{0.5}}, shouldErr: false, }, { a: [][]float64{ {1, 0, 0}, {0, 1, 0}, {0, 0, 1}, }, b: [][]float64{ {3}, {2}, {1}, }, ans: [][]float64{ {3}, {2}, {1}, }, shouldErr: false, }, { a: [][]float64{ {0.8147, 0.9134, 0.5528}, {0.9058, 0.6324, 0.8723}, {0.1270, 0.0975, 0.7612}, }, b: [][]float64{ {0.278}, {0.547}, {0.958}, }, ans: [][]float64{ {-0.932687281002860}, {0.303963920182067}, {1.375216503507109}, }, shouldErr: false, }, { a: [][]float64{ {0.8147, 0.9134, 0.5528}, {0.9058, 0.6324, 0.8723}, }, b: [][]float64{ {0.278}, {0.547}, }, ans: [][]float64{ {0.25919787248965376}, {-0.25560256266441034}, {0.5432324059702451}, }, shouldErr: false, }, { a: [][]float64{ {0.8147, 0.9134, 0.9}, {0.9058, 0.6324, 0.9}, {0.1270, 0.0975, 0.1}, {1.6, 2.8, -3.5}, }, b: [][]float64{ {0.278}, {0.547}, {-0.958}, {1.452}, }, ans: [][]float64{ {0.820970340787782}, {-0.218604626527306}, {-0.212938815234215}, }, shouldErr: false, }, { a: [][]float64{ {0.8147, 0.9134, 0.231, -1.65}, {0.9058, 0.6324, 0.9, 0.72}, {0.1270, 0.0975, 0.1, 1.723}, {1.6, 2.8, -3.5, 0.987}, {7.231, 9.154, 1.823, 0.9}, }, b: [][]float64{ {0.278, 8.635}, {0.547, 9.125}, {-0.958, -0.762}, {1.452, 1.444}, {1.999, -7.234}, }, ans: [][]float64{ {1.863006789511373, 44.467887791812750}, {-1.127270935407224, -34.073794226035126}, {-0.527926457947330, -8.032133759788573}, {-0.248621916204897, -2.366366415805275}, }, shouldErr: false, }, { a: [][]float64{ {0, 0}, {0, 0}, }, b: [][]float64{ {3}, {2}, }, ans: nil, shouldErr: true, }, { a: [][]float64{ {0, 0}, {0, 0}, {0, 0}, }, b: [][]float64{ {3}, {2}, {1}, }, ans: nil, shouldErr: true, }, { a: [][]float64{ {0, 0, 0}, {0, 0, 0}, }, b: [][]float64{ {3}, {2}, }, ans: nil, shouldErr: true, }, } { a := NewDense(flatten(test.a)) b := NewDense(flatten(test.b)) var ans *Dense if test.ans != nil { ans = NewDense(flatten(test.ans)) } var x Dense err := x.Solve(a, b) if err != nil { if !test.shouldErr { t.Errorf("Unexpected solve error: %s", err) } continue } if err == nil && test.shouldErr { t.Errorf("Did not error during solve.") continue } if !EqualApprox(&x, ans, 1e-12) { t.Errorf("Solve answer mismatch. Want %v, got %v", ans, x) } } // Random Cases. for _, test := range []struct { m, n, bc int }{ {5, 5, 1}, {5, 10, 1}, {10, 5, 1}, {5, 5, 7}, {5, 10, 7}, {10, 5, 7}, {5, 5, 12}, {5, 10, 12}, {10, 5, 12}, } { m := test.m n := test.n bc := test.bc a := NewDense(m, n, nil) for i := 0; i < m; i++ { for j := 0; j < n; j++ { a.Set(i, j, rnd.Float64()) } } br := m b := NewDense(br, bc, nil) for i := 0; i < br; i++ { for j := 0; j < bc; j++ { b.Set(i, j, rnd.Float64()) } } var x Dense err := x.Solve(a, b) if err != nil { t.Errorf("unexpected error from dense solve: %v", err) } // Test that the normal equations hold. // Aᵀ * A * x = Aᵀ * b var tmp, lhs, rhs Dense tmp.Mul(a.T(), a) lhs.Mul(&tmp, &x) rhs.Mul(a.T(), b) if !EqualApprox(&lhs, &rhs, 1e-10) { t.Errorf("Normal equations do not hold.\nLHS: %v\n, RHS: %v\n", lhs, rhs) } } // Use testTwoInput. method := func(receiver, a, b Matrix) { type Solver interface { Solve(a, b Matrix) error } rd := receiver.(Solver) _ = rd.Solve(a, b) } denseComparison := func(receiver, a, b *Dense) { _ = receiver.Solve(a, b) } testTwoInput(t, "Solve", &Dense{}, method, denseComparison, legalTypesAll, legalSizeSolve, 1e-7) } func TestSolveVec(t *testing.T) { t.Parallel() rnd := rand.New(rand.NewSource(1)) for _, test := range []struct { m, n int }{ {5, 5}, {5, 10}, {10, 5}, {5, 5}, {5, 10}, {10, 5}, {5, 5}, {5, 10}, {10, 5}, } { m := test.m n := test.n a := NewDense(m, n, nil) for i := 0; i < m; i++ { for j := 0; j < n; j++ { a.Set(i, j, rnd.Float64()) } } br := m b := NewVecDense(br, nil) for i := 0; i < br; i++ { b.SetVec(i, rnd.Float64()) } var x VecDense err := x.SolveVec(a, b) if err != nil { t.Errorf("unexpected error from dense vector solve: %v", err) } // Test that the normal equations hold. // Aᵀ * A * x = Aᵀ * b var tmp, lhs, rhs Dense tmp.Mul(a.T(), a) lhs.Mul(&tmp, &x) rhs.Mul(a.T(), b) if !EqualApprox(&lhs, &rhs, 1e-10) { t.Errorf("Normal equations do not hold.\nLHS: %v\n, RHS: %v\n", lhs, rhs) } } // Use testTwoInput method := func(receiver, a, b Matrix) { type SolveVecer interface { SolveVec(a Matrix, b Vector) error } rd := receiver.(SolveVecer) _ = rd.SolveVec(a, b.(Vector)) } denseComparison := func(receiver, a, b *Dense) { _ = receiver.Solve(a, b) } testTwoInput(t, "SolveVec", &VecDense{}, method, denseComparison, legalTypesMatrixVector, legalSizeSolve, 1e-12) }