diff --git a/csrc/fastdeploy/function/eigen.h b/csrc/fastdeploy/function/eigen.h index 32bacf064..ac2902e3d 100644 --- a/csrc/fastdeploy/function/eigen.h +++ b/csrc/fastdeploy/function/eigen.h @@ -18,6 +18,7 @@ #include #include #include "fastdeploy/core/fd_tensor.h" +#include "fastdeploy/utils/axis_utils.h" #include "unsupported/Eigen/CXX11/Tensor" namespace fastdeploy { @@ -96,6 +97,30 @@ struct EigenVector : public EigenTensor { } }; +template +struct EigenMatrix : public EigenTensor { + static typename EigenMatrix::Type Reshape(FDTensor& tensor, // NOLINT + int num_col_dims) { + int rank = tensor.shape.size(); + FDASSERT((num_col_dims > 0 && num_col_dims < rank), + "Input dimension number(num_col_dims)."); + const int n = SizeToAxis(num_col_dims, tensor.shape); + const int d = SizeFromAxis(num_col_dims, tensor.shape); + return EigenMatrix::From(tensor, {n, d}); + } + + static typename EigenMatrix::ConstType Reshape(const FDTensor& tensor, + int num_col_dims) { + int rank = tensor.shape.size(); + FDASSERT((num_col_dims > 0 && num_col_dims < rank), + "Input dimension number(num_col_dims)."); + const int n = SizeToAxis(num_col_dims, tensor.shape); + const int d = SizeFromAxis(num_col_dims, tensor.shape); + return EigenMatrix::From(tensor, {n, d}); + } +}; + class EigenDeviceWrapper { public: static std::shared_ptr GetInstance(); diff --git a/csrc/fastdeploy/function/softmax.cc b/csrc/fastdeploy/function/softmax.cc new file mode 100644 index 000000000..2eb85e52c --- /dev/null +++ b/csrc/fastdeploy/function/softmax.cc @@ -0,0 +1,123 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fastdeploy/function/softmax.h" + +#include + +#include "fastdeploy/function/eigen.h" +#include "fastdeploy/utils/axis_utils.h" +#include "fastdeploy/utils/utils.h" + +namespace fastdeploy { +#ifdef ENABLE_FDTENSOR_FUNC + +template +struct ValueClip { + T operator()(const T& x) const { + const T kThreshold = static_cast(-64.); + return x < kThreshold ? kThreshold : x; + } +}; + +template +struct SoftmaxEigen { + void operator()(const FDTensor& x, FDTensor* out, int axis_dim) { + constexpr int kBatchDim = 0; + constexpr int kClassDim = 1; + constexpr int kAxisDim = 1; + + auto logits = EigenMatrix::From(x); + auto softmax = EigenMatrix::From(*out); + + const int batch_size = logits.dimension(kBatchDim); + const int num_classes = logits.dimension(kClassDim); + const int num_remain = num_classes / axis_dim; + Eigen::DSizes along_axis(kAxisDim); + Eigen::DSizes batch_classes(batch_size, num_classes); + Eigen::DSizes batch_by_one(batch_size, 1); + Eigen::DSizes one_by_class(1, num_classes); + Eigen::DSizes batch_one_remain(batch_size, 1, num_remain); + Eigen::DSizes one_axis_one(1, axis_dim, 1); + Eigen::DSizes one_axis(1, axis_dim); + Eigen::DSizes batch_axis_remain(batch_size, axis_dim, num_remain); + + const auto& dev = *EigenDeviceWrapper::GetInstance()->GetDevice(); + // For numerical stability, logits should be shifted by maximum number along + // axis, calculate shifted_logits into softmax tensor for memory reuse. + if (num_remain == 1) { + // axis == -1, axis and class in same dimension, calculate along + // class dimension directly for higher performance + softmax.device(dev) = (logits - + logits.maximum(along_axis) + .eval() + .reshape(batch_by_one) + .broadcast(one_by_class)) + .unaryExpr(ValueClip()); + } else { + // axis != -1, class dimension split into (axis, remain), max and sum + // should be calculated along axis dimension + softmax.device(dev) = (logits.reshape(batch_axis_remain) - + logits.reshape(batch_axis_remain) + .maximum(along_axis) + .eval() + .reshape(batch_one_remain) + .broadcast(one_axis_one) + .reshape(batch_axis_remain)) + .reshape(batch_classes) + .unaryExpr(ValueClip()); + } + softmax.device(dev) = softmax.exp(); + softmax.device(dev) = (softmax * + softmax.reshape(batch_axis_remain) + .sum(along_axis) + .inverse() + .eval() + .broadcast(one_axis)); + } +}; + +template +void SoftmaxFunctor(const FDTensor& x, FDTensor* out, int axis) { + SoftmaxEigen()(x, out, axis); +} + +template +void SoftmaxKernel(const FDTensor& x, FDTensor* out, int axis) { + const int rank = x.shape.size(); + const int calc_axis = CanonicalAxis(axis, rank); + int axis_dim = x.shape[calc_axis]; + out->Allocate(x.shape, x.dtype); + if (out->Numel() == 0) { + return; + } + const int n = SizeToAxis(calc_axis, x.shape); + const int d = SizeFromAxis(calc_axis, x.shape); + // Reshape to 2d tensor + + FDTensor x_2d, out_2d; + x_2d.SetExternalData({n, d}, x.dtype, const_cast(x.Data())); + out_2d.SetExternalData({n, d}, out->dtype, out->Data()); + + SoftmaxFunctor(x_2d, &out_2d, axis_dim); +} + +void Softmax(const FDTensor& x, FDTensor* out, int axis) { + FDASSERT(std::abs(axis) < x.shape.size(), + "The given axis should be smaller than the input's dimension"); + FD_VISIT_FLOAT_TYPES(x.dtype, "SoftmaxKernel", + ([&] { SoftmaxKernel(x, out, axis); })); +} +#endif +} // namespace fastdeploy diff --git a/csrc/fastdeploy/function/softmax.h b/csrc/fastdeploy/function/softmax.h new file mode 100644 index 000000000..869fb7460 --- /dev/null +++ b/csrc/fastdeploy/function/softmax.h @@ -0,0 +1,30 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "fastdeploy/core/fd_tensor.h" + +namespace fastdeploy { + +#ifdef ENABLE_FDTENSOR_FUNC +/** Excute the softmax operation for input FDTensor along given dims. + @param x The input tensor. + @param out The output tensor which stores the result. + @param axis The axis to be computed softmax value. +*/ +FASTDEPLOY_DECL void Softmax(const FDTensor& x, FDTensor* out, int axis = -1); + +#endif +} // namespace fastdeploy diff --git a/csrc/fastdeploy/utils/axis_utils.h b/csrc/fastdeploy/utils/axis_utils.h new file mode 100644 index 000000000..5d82a49fd --- /dev/null +++ b/csrc/fastdeploy/utils/axis_utils.h @@ -0,0 +1,52 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +namespace fastdeploy { + +static inline int CanonicalAxis(const int axis, const int rank) { + if (axis < 0) { + return axis + rank; + } + return axis; +} + +static inline int SizeToAxis(const int axis, const std::vector& dims) { + int size = 1; + for (int i = 0; i < axis; i++) { + size *= dims[i]; + } + return size; +} + +static inline int SizeFromAxis(const int axis, + const std::vector& dims) { + int size = 1; + for (int i = axis; i < dims.size(); i++) { + size *= dims[i]; + } + return size; +} + +static inline int SizeOutAxis(const int axis, + const std::vector& dims) { + int size = 1; + for (int i = axis + 1; i < dims.size(); i++) { + size *= dims[i]; + } + return size; +} + +} // namespace fastdeploy diff --git a/csrc/fastdeploy/utils/utils.h b/csrc/fastdeploy/utils/utils.h index 84d2ee74c..b4f626d7b 100644 --- a/csrc/fastdeploy/utils/utils.h +++ b/csrc/fastdeploy/utils/utils.h @@ -103,36 +103,40 @@ FASTDEPLOY_DECL bool ReadBinaryFromFile(const std::string& file, #define FD_PRIVATE_CASE_TYPE(NAME, enum_type, type, ...) \ FD_PRIVATE_CASE_TYPE_USING_HINT(NAME, enum_type, type, data_t, __VA_ARGS__) -#define FD_VISIT_ALL_TYPES(TYPE, NAME, ...) \ - [&] { \ - const auto& __dtype__ = TYPE; \ - switch (__dtype__) { \ - FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::BOOL, bool, \ - __VA_ARGS__) \ - FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::INT32, int32_t, \ - __VA_ARGS__) \ - FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::INT64, int64_t, \ - __VA_ARGS__) \ - FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::FP32, float, \ - __VA_ARGS__) \ - FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::FP64, double, \ - __VA_ARGS__) \ - default: \ - FDASSERT(false, "Invalid enum data type.") \ - } \ +// Visit different data type to match the corresponding function of FDTensor +#define FD_VISIT_ALL_TYPES(TYPE, NAME, ...) \ + [&] { \ + const auto& __dtype__ = TYPE; \ + switch (__dtype__) { \ + FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::BOOL, bool, \ + __VA_ARGS__) \ + FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::INT32, int32_t, \ + __VA_ARGS__) \ + FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::INT64, int64_t, \ + __VA_ARGS__) \ + FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::FP32, float, \ + __VA_ARGS__) \ + FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::FP64, double, \ + __VA_ARGS__) \ + default: \ + FDASSERT(false, \ + "Invalid enum data type. Only accept data type BOOL, INT32, " \ + "INT64, FP32, FP64.") \ + } \ }() -#define FD_VISIT_FLOAT_TYPES(TYPE, NAME, ...) \ - [&] { \ - const auto& __dtype__ = TYPE; \ - switch (__dtype__) { \ - FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::FP32, float, \ - __VA_ARGS__) \ - FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::FP64, double, \ - __VA_ARGS__) \ - default: \ - FDASSERT(false, "Invalid enum data type.") \ - } \ +#define FD_VISIT_FLOAT_TYPES(TYPE, NAME, ...) \ + [&] { \ + const auto& __dtype__ = TYPE; \ + switch (__dtype__) { \ + FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::FP32, float, \ + __VA_ARGS__) \ + FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::FP64, double, \ + __VA_ARGS__) \ + default: \ + FDASSERT(false, \ + "Invalid enum data type. Only accept data type FP32, FP64.") \ + } \ }() #define FD_VISIT_INT_TYPES(TYPE, NAME, ...) \ @@ -144,7 +148,9 @@ FASTDEPLOY_DECL bool ReadBinaryFromFile(const std::string& file, FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::INT64, int64_t, \ __VA_ARGS__) \ default: \ - FDASSERT(false, "Invalid enum data type.") \ + FDASSERT( \ + false, \ + "Invalid enum data type. Only accept data type INT32, INT64.") \ } \ }() diff --git a/docs/api/function.md b/docs/api/function.md index e62ade45c..ee6b314a8 100644 --- a/docs/api/function.md +++ b/docs/api/function.md @@ -1,6 +1,6 @@ # FDTensor C++ 张量化函数 -FDTensor是FastDeploy在C++层表示张量的结构体。该结构体主要用于管理推理部署时模型的输入输出数据,支持在不同的Runtime后端中使用。在基于C++的推理部署应用开发过程中,我们往往需要对输入输出的数据进行一些数据处理,用以得到模型的实际输入或者应用的实际输出。这种数据预处理的逻辑可以使用原生的C++标准库来实现,但开发难度会比较大,如对3维Tensor的第2维求最大值。针对这个问题,FastDeploy基于FDTensor开发了一套C++张量化函数,用于降低FastDeploy用户的开发成本,提高开发效率。目前主要分为三类函数:Reduce类函数,Manipulate类函数,Elementwise类函数。 +FDTensor是FastDeploy在C++层表示张量的结构体。该结构体主要用于管理推理部署时模型的输入输出数据,支持在不同的Runtime后端中使用。在基于C++的推理部署应用开发过程中,我们往往需要对输入输出的数据进行一些数据处理,用以得到模型的实际输入或者应用的实际输出。这种数据预处理的逻辑可以使用原生的C++标准库来实现,但开发难度会比较大,如对3维Tensor的第2维求最大值。针对这个问题,FastDeploy基于FDTensor开发了一套C++张量化函数,用于降低FastDeploy用户的开发成本,提高开发效率。目前主要分为三类函数:Reduce类函数,Manipulate类函数,Math类函数以及Elementwise类函数。 ## Reduce类函数 @@ -239,6 +239,39 @@ input.SetExternalData({2, 3}, FDDataType::FP32, inputs.data()); Transpose(input, &output, {1, 0}); ``` +## Math类函数 + +目前FastDeploy支持1种Math类函数:Softmax。 + +### Softmax + +#### 函数签名 + +```c++ +/** Excute the softmax operation for input FDTensor along given dims. + @param x The input tensor. + @param out The output tensor which stores the result. + @param axis The axis to be computed softmax value. +*/ +void Softmax(const FDTensor& x, FDTensor* out, int axis = -1); +``` + +#### 使用示例 + +```c++ +FDTensor input, output; +CheckShape check_shape; +CheckData check_data; +std::vector inputs = {1, 2, 3, 4, 5, 6}; +input.SetExternalData({2, 3}, FDDataType::FP32, inputs.data()); + +// Transpose the input tensor with axis {1, 0}. +// The output result would be +// [[0.04742587, 0.04742587, 0.04742587], +// [0.95257413, 0.95257413, 0.95257413]] +Softmax(input, &output, 0); +``` + ## Elementwise类函数 diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index a1cbf0ca5..1a1d35fb4 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -60,11 +60,12 @@ function(add_fastdeploy_unittest CC_FILE) endfunction() if(WITH_TESTING) + include_directories(${CMAKE_CURRENT_SOURCE_DIR}) add_library(fastdeploy_gtest_main STATIC gtest_main) target_link_libraries(fastdeploy_gtest_main PUBLIC gtest gflags) message(STATUS "") message(STATUS "*************FastDeploy Unittest Summary**********") - file(GLOB ALL_TEST_SRCS ${PROJECT_SOURCE_DIR}/tests/test_*.cc) + file(GLOB_RECURSE ALL_TEST_SRCS ${PROJECT_SOURCE_DIR}/tests/*/test_*.cc) foreach(_CC_FILE ${ALL_TEST_SRCS}) add_fastdeploy_unittest(${_CC_FILE}) endforeach() diff --git a/tests/test_reduce.cc b/tests/function/test_reduce.cc similarity index 100% rename from tests/test_reduce.cc rename to tests/function/test_reduce.cc diff --git a/tests/function/test_softmax.cc b/tests/function/test_softmax.cc new file mode 100644 index 000000000..0a2d60a56 --- /dev/null +++ b/tests/function/test_softmax.cc @@ -0,0 +1,48 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "fastdeploy/core/fd_tensor.h" +#include "fastdeploy/function/softmax.h" +#include "glog/logging.h" +#include "gtest/gtest.h" +#include "gtest_utils.h" + +namespace fastdeploy { +#ifdef ENABLE_FDTENSOR_FUNC +TEST(fastdeploy, softmax) { + FDTensor input, output; + CheckShape check_shape; + CheckData check_data; + std::vector inputs = {1, 2, 3, 4, 5, 6}; + std::vector expected_result_axis0 = { + 0.04742587, 0.04742587, 0.04742587, 0.95257413, 0.95257413, 0.95257413}; + std::vector expected_result_axis1 = { + 0.09003057, 0.24472846, 0.66524088, 0.09003057, 0.24472846, 0.66524088}; + input.SetExternalData({2, 3}, FDDataType::FP32, inputs.data()); + + // axis = 0 + Softmax(input, &output, 0); + check_shape(output.shape, {2, 3}); + check_data(reinterpret_cast(output.Data()), + expected_result_axis0.data(), expected_result_axis0.size()); + + // axis = 1 + Softmax(input, &output, 1); + check_shape(output.shape, {2, 3}); + check_data(reinterpret_cast(output.Data()), + expected_result_axis1.data(), expected_result_axis1.size()); +} +#endif +} // namespace fastdeploy \ No newline at end of file diff --git a/tests/test_transpose.cc b/tests/function/test_transpose.cc similarity index 100% rename from tests/test_transpose.cc rename to tests/function/test_transpose.cc