Files
FastDeploy/third_party/eigen/test/product_trmm.cpp
Jack Zhou 355382ad63 Move eigen to third party (#282)
* remove useless statement

* Add eigen to third_party dir

* remove redundant lines
2022-09-26 19:24:02 +08:00

198 lines
8.1 KiB
C++

// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2008-2009 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#include "main.h"
// Draws a random matrix dimension for scalar type T, between 1 and a
// type-dependent limit. The limit is EIGEN_TEST_MAX_SIZE divided by
// NumTraits<T>::ReadCost whenever EIGEN_TEST_MAX_SIZE exceeds twice that
// cost (so costlier scalar types get smaller matrices); otherwise it is
// EIGEN_TEST_MAX_SIZE itself.
template <typename T>
int get_random_size() {
  const int cost = NumTraits<T>::ReadCost;
  int limit = EIGEN_TEST_MAX_SIZE;
  if (EIGEN_TEST_MAX_SIZE > 2 * cost) {
    limit = EIGEN_TEST_MAX_SIZE / cost;
  }
  return internal::random<int>(1, limit);
}
// Tests products between a triangular view of a random rows x cols matrix
// and general (dense) operands on either side. Each result is compared with
// the same product computed from an explicitly materialized triangular copy.
//
// Template parameters:
//   Scalar     - element type of every matrix in the test.
//   Mode       - mode passed to triangularView<Mode>(). When it contains the
//                UnitDiag bit, the checks guarded by (Mode & UnitDiag) == 0
//                are skipped (see the comment before the first such check).
//   TriOrder   - storage order of mat and of the dense references built from
//                it.
//   OtherOrder - storage order of the general operands ge_right / ge_left;
//                ignored when OtherCols == 1.
//   ResOrder   - storage order of the results ge_xs / ge_sx; ignored when
//                OtherCols == 1.
//   OtherCols  - compile-time column count of ge_right (and row count of
//                ge_left), or Dynamic.
// Function parameters:
//   rows, cols - shape of the triangular source matrix, which need not be
//                square. Each defaults to get_random_size<Scalar>().
//   otherCols  - run-time size of the OtherCols dimension. It defaults to a
//                random size only when OtherCols is Dynamic, and to
//                OtherCols otherwise.
template <typename Scalar, int Mode, int TriOrder, int OtherOrder, int ResOrder,
int OtherCols>
void trmm(int rows = get_random_size<Scalar>(),
int cols = get_random_size<Scalar>(),
int otherCols = OtherCols == Dynamic ? get_random_size<Scalar>()
: OtherCols) {
// Operand and result types. When OtherCols == 1 the general operands are
// vectors. Their storage order is then fixed (ColMajor for the Dynamic x 1
// shape, RowMajor for the 1 x Dynamic shape) instead of following
// OtherOrder / ResOrder.
typedef Matrix<Scalar, Dynamic, Dynamic, TriOrder> TriMatrix;
typedef Matrix<Scalar, Dynamic, OtherCols,
OtherCols == 1 ? ColMajor : OtherOrder>
OnTheRight;
typedef Matrix<Scalar, OtherCols, Dynamic,
OtherCols == 1 ? RowMajor : OtherOrder>
OnTheLeft;
typedef Matrix<Scalar, Dynamic, OtherCols,
OtherCols == 1 ? ColMajor : ResOrder>
ResXS;
typedef Matrix<Scalar, OtherCols, Dynamic,
OtherCols == 1 ? RowMajor : ResOrder>
ResSX;
// mat is the random source. tri / triTr hold the triangular part of mat and
// of its transpose as plain dense matrices; s1tri / s1triTr hold the same
// for s1 * mat. These dense copies are the expected-value references.
TriMatrix mat(rows, cols), tri(rows, cols), triTr(cols, rows),
s1tri(rows, cols), s1triTr(cols, rows);
// ge_right multiplies the triangular matrix from the right, with results
// going to ge_xs. ge_left multiplies it from the left, with results going to
// ge_sx. The *_save matrices snapshot a result before it is updated in
// place.
OnTheRight ge_right(cols, otherCols);
OnTheLeft ge_left(otherCols, rows);
ResSX ge_sx, ge_sx_save;
ResXS ge_xs, ge_xs_save;
// Random scalar factors applied to the triangular and general operands.
Scalar s1 = internal::random<Scalar>(), s2 = internal::random<Scalar>();
mat.setRandom();
tri = mat.template triangularView<Mode>();
triTr = mat.transpose().template triangularView<Mode>();
s1tri = (s1 * mat).template triangularView<Mode>();
s1triTr = (s1 * mat).transpose().template triangularView<Mode>();
ge_right.setRandom();
ge_left.setRandom();
// Triangular * dense and dense * triangular, first with plain assignment and
// then with noalias() assignment.
VERIFY_IS_APPROX(ge_xs = mat.template triangularView<Mode>() * ge_right,
tri * ge_right);
VERIFY_IS_APPROX(ge_sx = ge_left * mat.template triangularView<Mode>(),
ge_left * tri);
VERIFY_IS_APPROX(
ge_xs.noalias() = mat.template triangularView<Mode>() * ge_right,
tri * ge_right);
VERIFY_IS_APPROX(
ge_sx.noalias() = ge_left * mat.template triangularView<Mode>(),
ge_left * tri);
// Scaled, transposed and adjoint operands.
// Each check guarded by (Mode & UnitDiag) == 0 compares a triangular view of
// a scaled (possibly conjugated) matrix against a reference that scales an
// already-triangular dense matrix. For a unit-diagonal Mode the view treats
// the diagonal as ones, while the reference also scales that diagonal. The
// comparison is therefore only meaningful for non-unit modes. The unguarded
// checks compare against s1tri / s1triTr, which are views of the scaled
// matrix themselves.
if ((Mode & UnitDiag) == 0)
VERIFY_IS_APPROX(
ge_xs.noalias() = (s1 * mat.adjoint()).template triangularView<Mode>() *
(s2 * ge_left.transpose()),
s1 * triTr.conjugate() * (s2 * ge_left.transpose()));
VERIFY_IS_APPROX(
ge_xs.noalias() = (s1 * mat.transpose()).template triangularView<Mode>() *
(s2 * ge_left.transpose()),
s1triTr * (s2 * ge_left.transpose()));
// Dense operand on the left of a scaled / transposed / adjoint triangular
// view.
VERIFY_IS_APPROX(ge_sx.noalias() = (s2 * ge_left) *
(s1 * mat).template triangularView<Mode>(),
(s2 * ge_left) * s1tri);
VERIFY_IS_APPROX(
ge_sx.noalias() =
ge_right.transpose() * mat.adjoint().template triangularView<Mode>(),
ge_right.transpose() * triTr.conjugate());
VERIFY_IS_APPROX(
ge_sx.noalias() =
ge_right.adjoint() * mat.adjoint().template triangularView<Mode>(),
ge_right.adjoint() * triTr.conjugate());
// In-place accumulation (+= and -=). The expected value is computed from a
// snapshot taken beforehand, not from ge_xs / ge_sx themselves. This is
// presumably because the two arguments of VERIFY_IS_APPROX may be evaluated
// in either order, and the second one modifies the destination.
ge_xs_save = ge_xs;
if ((Mode & UnitDiag) == 0)
VERIFY_IS_APPROX(
(ge_xs_save + s1 * triTr.conjugate() * (s2 * ge_left.adjoint())).eval(),
ge_xs.noalias() +=
(s1 * mat.adjoint()).template triangularView<Mode>() *
(s2 * ge_left.adjoint()));
ge_xs_save = ge_xs;
VERIFY_IS_APPROX((ge_xs_save + s1triTr * (s2 * ge_left.adjoint())).eval(),
ge_xs.noalias() +=
(s1 * mat.transpose()).template triangularView<Mode>() *
(s2 * ge_left.adjoint()));
// Restart from fresh random contents of ge_sx before the -= check.
ge_sx.setRandom();
ge_sx_save = ge_sx;
if ((Mode & UnitDiag) == 0)
VERIFY_IS_APPROX(
ge_sx_save - (ge_right.adjoint() * (-s1 * triTr).conjugate()).eval(),
ge_sx.noalias() -=
(ge_right.adjoint() *
(-s1 * mat).adjoint().template triangularView<Mode>())
.eval());
// Plain (aliasing-safe) assignment from an adjoint / transposed view of a
// scaled matrix.
if ((Mode & UnitDiag) == 0)
VERIFY_IS_APPROX(
ge_xs = (s1 * mat).adjoint().template triangularView<Mode>() *
ge_left.adjoint(),
numext::conj(s1) * triTr.conjugate() * ge_left.adjoint());
VERIFY_IS_APPROX(
ge_xs = (s1 * mat).transpose().template triangularView<Mode>() *
ge_left.adjoint(),
s1triTr * ge_left.adjoint());
// TODO check with sub-matrix expressions ?
// destination with a non-default inner-stride
// see bug 1741
{
// Recompute ge_xs so that it has the shape of tri * ge_right
// (rows x otherCols). buffer and map1 below are sized from it.
VERIFY_IS_APPROX(
ge_xs.noalias() = mat.template triangularView<Mode>() * ge_right,
tri * ge_right);
typedef Matrix<Scalar, Dynamic, Dynamic> MatrixX;
// map1 views a zeroed buffer that is twice as large in each dimension. It
// uses inner stride 2 and twice ge_xs's outer stride, so the product must
// be written through a non-unit inner stride.
MatrixX buffer(2 * ge_xs.rows(), 2 * ge_xs.cols());
Map<ResXS, 0, Stride<Dynamic, 2> > map1(
buffer.data(), ge_xs.rows(), ge_xs.cols(),
Stride<Dynamic, 2>(2 * ge_xs.outerStride(), 2));
buffer.setZero();
VERIFY_IS_APPROX(
map1.noalias() = mat.template triangularView<Mode>() * ge_right,
tri * ge_right);
}
}
// Matrix-vector counterpart of trmm: runs the same checks with a single
// general-operand column (OtherCols == 1). With one column the general and
// result operands are vectors whose storage order trmm fixes itself, so the
// ColMajor OtherOrder / ResOrder arguments passed here have no effect.
template <typename Scalar, int Mode, int TriOrder>
void trmv(int rows = get_random_size<Scalar>(),
          int cols = get_random_size<Scalar>()) {
  const int single_column = 1;
  trmm<Scalar, Mode, TriOrder, ColMajor, ColMajor, 1>(rows, cols,
                                                      single_column);
}
// Convenience overload: runs the general trmm check with a run-time
// (Dynamic) number of general-operand columns. All three sizes default to
// independent calls of get_random_size<Scalar>().
template <typename Scalar, int Mode, int TriOrder, int OtherOrder, int ResOrder>
void trmm(int rows = get_random_size<Scalar>(),
          int cols = get_random_size<Scalar>(),
          int otherCols = get_random_size<Scalar>()) {
  // Spelling out OtherCols = Dynamic selects the six-parameter template;
  // otherCols then carries the actual column count.
  trmm<Scalar, Mode, TriOrder, OtherOrder, ResOrder, Dynamic>(
      rows, cols, otherCols);
}
// Runs trmm for all eight combinations of (TriOrder, OtherOrder, ResOrder)
// storage orders, then the two trmv (single general-operand column)
// variants. Each run uses scalar type SCALAR and triangular mode MODE.
//
// NB selects the subtest macros:
//   - the trmm calls go through the macro named by appending NB to the
//     CALL_SUBTEST_ prefix;
//   - the trmv calls go through the macro named by appending "1" and then NB.
// For example, NB = 11 gives subtests 11 and 111, matching the suffix lists
// in the test entry point. EIGEN_CAT is presumably a two-level token-paste,
// so NB may itself be an EIGEN_CAT expression, as CALL_ALL passes it.
//
// NOTE(review): the build likely discovers subtest numbers by scanning this
// file's raw text. Keep the literal subtest-macro spellings below unchanged,
// and do not write that prefix immediately followed by digits in comments.
#define CALL_ALL_ORDERS(NB, SCALAR, MODE) \
EIGEN_CAT(CALL_SUBTEST_, NB) \
((trmm<SCALAR, MODE, ColMajor, ColMajor, ColMajor>())); \
EIGEN_CAT(CALL_SUBTEST_, NB) \
((trmm<SCALAR, MODE, ColMajor, ColMajor, RowMajor>())); \
EIGEN_CAT(CALL_SUBTEST_, NB) \
((trmm<SCALAR, MODE, ColMajor, RowMajor, ColMajor>())); \
EIGEN_CAT(CALL_SUBTEST_, NB) \
((trmm<SCALAR, MODE, ColMajor, RowMajor, RowMajor>())); \
EIGEN_CAT(CALL_SUBTEST_, NB) \
((trmm<SCALAR, MODE, RowMajor, ColMajor, ColMajor>())); \
EIGEN_CAT(CALL_SUBTEST_, NB) \
((trmm<SCALAR, MODE, RowMajor, ColMajor, RowMajor>())); \
EIGEN_CAT(CALL_SUBTEST_, NB) \
((trmm<SCALAR, MODE, RowMajor, RowMajor, ColMajor>())); \
EIGEN_CAT(CALL_SUBTEST_, NB) \
((trmm<SCALAR, MODE, RowMajor, RowMajor, RowMajor>())); \
\
EIGEN_CAT(CALL_SUBTEST_1, NB)((trmv<SCALAR, MODE, ColMajor>())); \
EIGEN_CAT(CALL_SUBTEST_1, NB)((trmv<SCALAR, MODE, RowMajor>()));
// Sweeps one scalar type TYPE over all six triangular modes. Upper and lower
// modes with the same diagonal kind share a subtest number:
//   - plain modes use 1 followed by SUFFIX;
//   - unit-diagonal modes use 2 followed by SUFFIX;
//   - strictly triangular modes use 3 followed by SUFFIX.
// Keep the expansion order as written: it fixes the order in which the
// subtests run.
#define CALL_ALL(SUFFIX, TYPE)                               \
  CALL_ALL_ORDERS(EIGEN_CAT(1, SUFFIX), TYPE, Upper)         \
  CALL_ALL_ORDERS(EIGEN_CAT(2, SUFFIX), TYPE, UnitUpper)     \
  CALL_ALL_ORDERS(EIGEN_CAT(3, SUFFIX), TYPE, StrictlyUpper) \
  CALL_ALL_ORDERS(EIGEN_CAT(1, SUFFIX), TYPE, Lower)         \
  CALL_ALL_ORDERS(EIGEN_CAT(2, SUFFIX), TYPE, UnitLower)     \
  CALL_ALL_ORDERS(EIGEN_CAT(3, SUFFIX), TYPE, StrictlyLower)
// Test driver: repeats the full sweep g_repeat times for float, double and
// their complex counterparts. The trailing comment on each line lists the
// subtest numbers that line expands to. NOTE(review): these suffix comments
// are presumably read by the build to register test parts, so keep them
// verbatim.
EIGEN_DECLARE_TEST(product_trmm) {
  for (int pass = 0; pass < g_repeat; ++pass) {
    CALL_ALL(1, float);                 // EIGEN_SUFFIXES;11;111;21;121;31;131
    CALL_ALL(2, double);                // EIGEN_SUFFIXES;12;112;22;122;32;132
    CALL_ALL(3, std::complex<float>);   // EIGEN_SUFFIXES;13;113;23;123;33;133
    CALL_ALL(4, std::complex<double>);  // EIGEN_SUFFIXES;14;114;24;124;34;134
  }
}