Add softmax function (#93)

* Add softmax function

* Add softmax unittest

* Add Softmax docs

* Add function directory

* Add comment for FD_VISIT_ALL_TYPES macro
This commit is contained in:
Jack Zhou
2022-08-11 12:01:19 +08:00
committed by GitHub
parent 4b67b6e8f9
commit cec0d502e0
10 changed files with 349 additions and 31 deletions

View File

@@ -18,6 +18,7 @@
#include <memory>
#include <vector>
#include "fastdeploy/core/fd_tensor.h"
#include "fastdeploy/utils/axis_utils.h"
#include "unsupported/Eigen/CXX11/Tensor"
namespace fastdeploy {
@@ -96,6 +97,30 @@ struct EigenVector : public EigenTensor<T, 1, MajorType, IndexType> {
}
};
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
struct EigenMatrix : public EigenTensor<T, 2, MajorType, IndexType> {
static typename EigenMatrix::Type Reshape(FDTensor& tensor, // NOLINT
int num_col_dims) {
int rank = tensor.shape.size();
FDASSERT((num_col_dims > 0 && num_col_dims < rank),
"Input dimension number(num_col_dims).");
const int n = SizeToAxis(num_col_dims, tensor.shape);
const int d = SizeFromAxis(num_col_dims, tensor.shape);
return EigenMatrix::From(tensor, {n, d});
}
static typename EigenMatrix::ConstType Reshape(const FDTensor& tensor,
int num_col_dims) {
int rank = tensor.shape.size();
FDASSERT((num_col_dims > 0 && num_col_dims < rank),
"Input dimension number(num_col_dims).");
const int n = SizeToAxis(num_col_dims, tensor.shape);
const int d = SizeFromAxis(num_col_dims, tensor.shape);
return EigenMatrix::From(tensor, {n, d});
}
};
class EigenDeviceWrapper {
public:
static std::shared_ptr<EigenDeviceWrapper> GetInstance();

View File

@@ -0,0 +1,123 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/function/softmax.h"
#include <cstdlib>
#include "fastdeploy/function/eigen.h"
#include "fastdeploy/utils/axis_utils.h"
#include "fastdeploy/utils/utils.h"
namespace fastdeploy {
#ifdef ENABLE_FDTENSOR_FUNC
template <typename T>
struct ValueClip {
T operator()(const T& x) const {
const T kThreshold = static_cast<T>(-64.);
return x < kThreshold ? kThreshold : x;
}
};
template <typename T>
struct SoftmaxEigen {
void operator()(const FDTensor& x, FDTensor* out, int axis_dim) {
constexpr int kBatchDim = 0;
constexpr int kClassDim = 1;
constexpr int kAxisDim = 1;
auto logits = EigenMatrix<T>::From(x);
auto softmax = EigenMatrix<T>::From(*out);
const int batch_size = logits.dimension(kBatchDim);
const int num_classes = logits.dimension(kClassDim);
const int num_remain = num_classes / axis_dim;
Eigen::DSizes<int, 1> along_axis(kAxisDim);
Eigen::DSizes<int, 2> batch_classes(batch_size, num_classes);
Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
Eigen::DSizes<int, 2> one_by_class(1, num_classes);
Eigen::DSizes<int, 3> batch_one_remain(batch_size, 1, num_remain);
Eigen::DSizes<int, 3> one_axis_one(1, axis_dim, 1);
Eigen::DSizes<int, 2> one_axis(1, axis_dim);
Eigen::DSizes<int, 3> batch_axis_remain(batch_size, axis_dim, num_remain);
const auto& dev = *EigenDeviceWrapper::GetInstance()->GetDevice();
// For numerical stability, logits should be shifted by maximum number along
// axis, calculate shifted_logits into softmax tensor for memory reuse.
if (num_remain == 1) {
// axis == -1, axis and class in same dimension, calculate along
// class dimension directly for higher performance
softmax.device(dev) = (logits -
logits.maximum(along_axis)
.eval()
.reshape(batch_by_one)
.broadcast(one_by_class))
.unaryExpr(ValueClip<T>());
} else {
// axis != -1, class dimension split into (axis, remain), max and sum
// should be calculated along axis dimension
softmax.device(dev) = (logits.reshape(batch_axis_remain) -
logits.reshape(batch_axis_remain)
.maximum(along_axis)
.eval()
.reshape(batch_one_remain)
.broadcast(one_axis_one)
.reshape(batch_axis_remain))
.reshape(batch_classes)
.unaryExpr(ValueClip<T>());
}
softmax.device(dev) = softmax.exp();
softmax.device(dev) = (softmax *
softmax.reshape(batch_axis_remain)
.sum(along_axis)
.inverse()
.eval()
.broadcast(one_axis));
}
};
template <typename T>
void SoftmaxFunctor(const FDTensor& x, FDTensor* out, int axis) {
SoftmaxEigen<T>()(x, out, axis);
}
template <typename T>
void SoftmaxKernel(const FDTensor& x, FDTensor* out, int axis) {
const int rank = x.shape.size();
const int calc_axis = CanonicalAxis(axis, rank);
int axis_dim = x.shape[calc_axis];
out->Allocate(x.shape, x.dtype);
if (out->Numel() == 0) {
return;
}
const int n = SizeToAxis(calc_axis, x.shape);
const int d = SizeFromAxis(calc_axis, x.shape);
// Reshape to 2d tensor
FDTensor x_2d, out_2d;
x_2d.SetExternalData({n, d}, x.dtype, const_cast<void*>(x.Data()));
out_2d.SetExternalData({n, d}, out->dtype, out->Data());
SoftmaxFunctor<T>(x_2d, &out_2d, axis_dim);
}
void Softmax(const FDTensor& x, FDTensor* out, int axis) {
FDASSERT(std::abs(axis) < x.shape.size(),
"The given axis should be smaller than the input's dimension");
FD_VISIT_FLOAT_TYPES(x.dtype, "SoftmaxKernel",
([&] { SoftmaxKernel<data_t>(x, out, axis); }));
}
#endif
} // namespace fastdeploy

View File

@@ -0,0 +1,30 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/core/fd_tensor.h"
namespace fastdeploy {
#ifdef ENABLE_FDTENSOR_FUNC
/** Excute the softmax operation for input FDTensor along given dims.
@param x The input tensor.
@param out The output tensor which stores the result.
@param axis The axis to be computed softmax value.
*/
FASTDEPLOY_DECL void Softmax(const FDTensor& x, FDTensor* out, int axis = -1);
#endif
} // namespace fastdeploy

View File

@@ -0,0 +1,52 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
namespace fastdeploy {
static inline int CanonicalAxis(const int axis, const int rank) {
if (axis < 0) {
return axis + rank;
}
return axis;
}
static inline int SizeToAxis(const int axis, const std::vector<int64_t>& dims) {
int size = 1;
for (int i = 0; i < axis; i++) {
size *= dims[i];
}
return size;
}
static inline int SizeFromAxis(const int axis,
const std::vector<int64_t>& dims) {
int size = 1;
for (int i = axis; i < dims.size(); i++) {
size *= dims[i];
}
return size;
}
static inline int SizeOutAxis(const int axis,
const std::vector<int64_t>& dims) {
int size = 1;
for (int i = axis + 1; i < dims.size(); i++) {
size *= dims[i];
}
return size;
}
} // namespace fastdeploy

View File

@@ -103,6 +103,7 @@ FASTDEPLOY_DECL bool ReadBinaryFromFile(const std::string& file,
#define FD_PRIVATE_CASE_TYPE(NAME, enum_type, type, ...) \
FD_PRIVATE_CASE_TYPE_USING_HINT(NAME, enum_type, type, data_t, __VA_ARGS__)
// Visit different data type to match the corresponding function of FDTensor
#define FD_VISIT_ALL_TYPES(TYPE, NAME, ...) \
[&] { \
const auto& __dtype__ = TYPE; \
@@ -118,7 +119,9 @@ FASTDEPLOY_DECL bool ReadBinaryFromFile(const std::string& file,
FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::FP64, double, \
__VA_ARGS__) \
default: \
FDASSERT(false, "Invalid enum data type.") \
FDASSERT(false, \
"Invalid enum data type. Only accept data type BOOL, INT32, " \
"INT64, FP32, FP64.") \
} \
}()
@@ -131,7 +134,8 @@ FASTDEPLOY_DECL bool ReadBinaryFromFile(const std::string& file,
FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::FP64, double, \
__VA_ARGS__) \
default: \
FDASSERT(false, "Invalid enum data type.") \
FDASSERT(false, \
"Invalid enum data type. Only accept data type FP32, FP64.") \
} \
}()
@@ -144,7 +148,9 @@ FASTDEPLOY_DECL bool ReadBinaryFromFile(const std::string& file,
FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::INT64, int64_t, \
__VA_ARGS__) \
default: \
FDASSERT(false, "Invalid enum data type.") \
FDASSERT( \
false, \
"Invalid enum data type. Only accept data type INT32, INT64.") \
} \
}()

View File

@@ -1,6 +1,6 @@
# FDTensor C++ 张量化函数
FDTensor是FastDeploy在C++层表示张量的结构体。该结构体主要用于管理推理部署时模型的输入输出数据支持在不同的Runtime后端中使用。在基于C++的推理部署应用开发过程中我们往往需要对输入输出的数据进行一些数据处理用以得到模型的实际输入或者应用的实际输出。这种数据预处理的逻辑可以使用原生的C++标准库来实现但开发难度会比较大如对3维Tensor的第2维求最大值。针对这个问题FastDeploy基于FDTensor开发了一套C++张量化函数用于降低FastDeploy用户的开发成本提高开发效率。目前主要分为三类函数Reduce类函数Manipulate类函数Elementwise类函数。
FDTensor是FastDeploy在C++层表示张量的结构体。该结构体主要用于管理推理部署时模型的输入输出数据支持在不同的Runtime后端中使用。在基于C++的推理部署应用开发过程中我们往往需要对输入输出的数据进行一些数据处理用以得到模型的实际输入或者应用的实际输出。这种数据预处理的逻辑可以使用原生的C++标准库来实现但开发难度会比较大如对3维Tensor的第2维求最大值。针对这个问题FastDeploy基于FDTensor开发了一套C++张量化函数用于降低FastDeploy用户的开发成本提高开发效率。目前主要分为三类函数Reduce类函数Manipulate类函数Math类函数以及Elementwise类函数。
## Reduce类函数
@@ -239,6 +239,39 @@ input.SetExternalData({2, 3}, FDDataType::FP32, inputs.data());
Transpose(input, &output, {1, 0});
```
## Math类函数
目前FastDeploy支持1种Math类函数Softmax。
### Softmax
#### 函数签名
```c++
/** Excute the softmax operation for input FDTensor along given dims.
@param x The input tensor.
@param out The output tensor which stores the result.
@param axis The axis to be computed softmax value.
*/
void Softmax(const FDTensor& x, FDTensor* out, int axis = -1);
```
#### 使用示例
```c++
FDTensor input, output;
CheckShape check_shape;
CheckData check_data;
std::vector<float> inputs = {1, 2, 3, 4, 5, 6};
input.SetExternalData({2, 3}, FDDataType::FP32, inputs.data());
// Transpose the input tensor with axis {1, 0}.
// The output result would be
// [[0.04742587, 0.04742587, 0.04742587],
// [0.95257413, 0.95257413, 0.95257413]]
Softmax(input, &output, 0);
```
## Elementwise类函数

View File

@@ -60,11 +60,12 @@ function(add_fastdeploy_unittest CC_FILE)
endfunction()
if(WITH_TESTING)
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
add_library(fastdeploy_gtest_main STATIC gtest_main)
target_link_libraries(fastdeploy_gtest_main PUBLIC gtest gflags)
message(STATUS "")
message(STATUS "*************FastDeploy Unittest Summary**********")
file(GLOB ALL_TEST_SRCS ${PROJECT_SOURCE_DIR}/tests/test_*.cc)
file(GLOB_RECURSE ALL_TEST_SRCS ${PROJECT_SOURCE_DIR}/tests/*/test_*.cc)
foreach(_CC_FILE ${ALL_TEST_SRCS})
add_fastdeploy_unittest(${_CC_FILE})
endforeach()

View File

@@ -0,0 +1,48 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <vector>
#include "fastdeploy/core/fd_tensor.h"
#include "fastdeploy/function/softmax.h"
#include "glog/logging.h"
#include "gtest/gtest.h"
#include "gtest_utils.h"
namespace fastdeploy {
#ifdef ENABLE_FDTENSOR_FUNC
TEST(fastdeploy, softmax) {
FDTensor input, output;
CheckShape check_shape;
CheckData check_data;
std::vector<float> inputs = {1, 2, 3, 4, 5, 6};
std::vector<float> expected_result_axis0 = {
0.04742587, 0.04742587, 0.04742587, 0.95257413, 0.95257413, 0.95257413};
std::vector<float> expected_result_axis1 = {
0.09003057, 0.24472846, 0.66524088, 0.09003057, 0.24472846, 0.66524088};
input.SetExternalData({2, 3}, FDDataType::FP32, inputs.data());
// axis = 0
Softmax(input, &output, 0);
check_shape(output.shape, {2, 3});
check_data(reinterpret_cast<const float*>(output.Data()),
expected_result_axis0.data(), expected_result_axis0.size());
// axis = 1
Softmax(input, &output, 1);
check_shape(output.shape, {2, 3});
check_data(reinterpret_cast<const float*>(output.Data()),
expected_result_axis1.data(), expected_result_axis1.size());
}
#endif
} // namespace fastdeploy