diff --git a/fastdeploy/function/math.cc b/fastdeploy/function/math.cc index 292b1ae15..3889ca698 100644 --- a/fastdeploy/function/math.cc +++ b/fastdeploy/function/math.cc @@ -39,6 +39,7 @@ DEFINE_ACTIVATION_KERNEL(Sqrt, SqrtFunctor) DEFINE_ACTIVATION_KERNEL(Log, LogFunctor) DEFINE_ACTIVATION_KERNEL(Round, RoundFunctor) DEFINE_ACTIVATION_KERNEL(Exp, ExpFunctor) +DEFINE_ACTIVATION_KERNEL(Abs, AbsFunctor) void Sqrt(const FDTensor& x, FDTensor* out) { FD_VISIT_FLOAT_TYPES(x.dtype, "SqrtKernel", @@ -60,5 +61,10 @@ void Exp(const FDTensor& x, FDTensor* out) { ([&] { ExpKernel(x, out); })); } +void Abs(const FDTensor& x, FDTensor* out) { + FD_VISIT_FLOAT_TYPES(x.dtype, "AbsKernel", + ([&] { AbsKernel(x, out); })); +} + } // namespace function } // namespace fastdeploy \ No newline at end of file diff --git a/fastdeploy/function/math.h b/fastdeploy/function/math.h index 3dd93c818..14ac79c64 100644 --- a/fastdeploy/function/math.h +++ b/fastdeploy/function/math.h @@ -43,5 +43,11 @@ FASTDEPLOY_DECL void Round(const FDTensor& x, FDTensor* out); */ FASTDEPLOY_DECL void Exp(const FDTensor& x, FDTensor* out); +/** This operator is used to perform elementwise abs for input X. Only for float type FDTensor + @param x The input tensor. + @param out The output tensor which stores the result. +*/ +FASTDEPLOY_DECL void Abs(const FDTensor& x, FDTensor* out); + } // namespace function } // namespace fastdeploy diff --git a/fastdeploy/function/math_functor.h b/fastdeploy/function/math_functor.h index 440ce94a9..f82224c97 100644 --- a/fastdeploy/function/math_functor.h +++ b/fastdeploy/function/math_functor.h @@ -52,5 +52,14 @@ template struct SqrtFunctor { } }; +// abs(x) = x if x > 0 else -x +template struct AbsFunctor { + template + void operator()(Device d, X x, Out out) const { + out.device(d) = + x.unaryExpr([](T v) { return v > static_cast(0) ? v : -v; }); + } +}; + } // namespace function } // namespace fastdeploy diff --git a/tests/function/test_math.cc b/tests/function/test_math.cc index ec09fa7ef..8027012e9 100644 --- a/tests/function/test_math.cc +++ b/tests/function/test_math.cc @@ -83,5 +83,18 @@ TEST(fastdeploy, exp_sqrt_round_log) { log_result.size()); } +TEST(fastdeploy, abs) { + CheckShape check_shape; + CheckData check_data; + FDTensor x, y; + std::vector test_data = {-1, 2, 3, -5, -4, -6}; + x.SetExternalData({2, 3}, FDDataType::FP32, test_data.data()); + std::vector result = {1, 2, 3, 5, 4, 6}; + Abs(x, &y); + check_shape(y.shape, {2, 3}); + check_data(reinterpret_cast(y.Data()), result.data(), + result.size()); +} + } // namespace function } // namespace fastdeploy \ No newline at end of file