mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 00:57:33 +08:00
Add softmax function (#93)
* Add softmax function * Add softmax unittest * Add Softmax docs * Add function directory * Add comment for FD_VISIT_ALL_TYPES macro
This commit is contained in:
@@ -18,6 +18,7 @@
|
|||||||
#include <memory>
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "fastdeploy/core/fd_tensor.h"
|
#include "fastdeploy/core/fd_tensor.h"
|
||||||
|
#include "fastdeploy/utils/axis_utils.h"
|
||||||
#include "unsupported/Eigen/CXX11/Tensor"
|
#include "unsupported/Eigen/CXX11/Tensor"
|
||||||
|
|
||||||
namespace fastdeploy {
|
namespace fastdeploy {
|
||||||
@@ -96,6 +97,30 @@ struct EigenVector : public EigenTensor<T, 1, MajorType, IndexType> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename T, int MajorType = Eigen::RowMajor,
|
||||||
|
typename IndexType = Eigen::DenseIndex>
|
||||||
|
struct EigenMatrix : public EigenTensor<T, 2, MajorType, IndexType> {
|
||||||
|
static typename EigenMatrix::Type Reshape(FDTensor& tensor, // NOLINT
|
||||||
|
int num_col_dims) {
|
||||||
|
int rank = tensor.shape.size();
|
||||||
|
FDASSERT((num_col_dims > 0 && num_col_dims < rank),
|
||||||
|
"Input dimension number(num_col_dims).");
|
||||||
|
const int n = SizeToAxis(num_col_dims, tensor.shape);
|
||||||
|
const int d = SizeFromAxis(num_col_dims, tensor.shape);
|
||||||
|
return EigenMatrix::From(tensor, {n, d});
|
||||||
|
}
|
||||||
|
|
||||||
|
static typename EigenMatrix::ConstType Reshape(const FDTensor& tensor,
|
||||||
|
int num_col_dims) {
|
||||||
|
int rank = tensor.shape.size();
|
||||||
|
FDASSERT((num_col_dims > 0 && num_col_dims < rank),
|
||||||
|
"Input dimension number(num_col_dims).");
|
||||||
|
const int n = SizeToAxis(num_col_dims, tensor.shape);
|
||||||
|
const int d = SizeFromAxis(num_col_dims, tensor.shape);
|
||||||
|
return EigenMatrix::From(tensor, {n, d});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
class EigenDeviceWrapper {
|
class EigenDeviceWrapper {
|
||||||
public:
|
public:
|
||||||
static std::shared_ptr<EigenDeviceWrapper> GetInstance();
|
static std::shared_ptr<EigenDeviceWrapper> GetInstance();
|
||||||
|
123
csrc/fastdeploy/function/softmax.cc
Normal file
123
csrc/fastdeploy/function/softmax.cc
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "fastdeploy/function/softmax.h"
|
||||||
|
|
||||||
|
#include <cstdlib>
|
||||||
|
|
||||||
|
#include "fastdeploy/function/eigen.h"
|
||||||
|
#include "fastdeploy/utils/axis_utils.h"
|
||||||
|
#include "fastdeploy/utils/utils.h"
|
||||||
|
|
||||||
|
namespace fastdeploy {
|
||||||
|
#ifdef ENABLE_FDTENSOR_FUNC
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct ValueClip {
|
||||||
|
T operator()(const T& x) const {
|
||||||
|
const T kThreshold = static_cast<T>(-64.);
|
||||||
|
return x < kThreshold ? kThreshold : x;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct SoftmaxEigen {
|
||||||
|
void operator()(const FDTensor& x, FDTensor* out, int axis_dim) {
|
||||||
|
constexpr int kBatchDim = 0;
|
||||||
|
constexpr int kClassDim = 1;
|
||||||
|
constexpr int kAxisDim = 1;
|
||||||
|
|
||||||
|
auto logits = EigenMatrix<T>::From(x);
|
||||||
|
auto softmax = EigenMatrix<T>::From(*out);
|
||||||
|
|
||||||
|
const int batch_size = logits.dimension(kBatchDim);
|
||||||
|
const int num_classes = logits.dimension(kClassDim);
|
||||||
|
const int num_remain = num_classes / axis_dim;
|
||||||
|
Eigen::DSizes<int, 1> along_axis(kAxisDim);
|
||||||
|
Eigen::DSizes<int, 2> batch_classes(batch_size, num_classes);
|
||||||
|
Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
|
||||||
|
Eigen::DSizes<int, 2> one_by_class(1, num_classes);
|
||||||
|
Eigen::DSizes<int, 3> batch_one_remain(batch_size, 1, num_remain);
|
||||||
|
Eigen::DSizes<int, 3> one_axis_one(1, axis_dim, 1);
|
||||||
|
Eigen::DSizes<int, 2> one_axis(1, axis_dim);
|
||||||
|
Eigen::DSizes<int, 3> batch_axis_remain(batch_size, axis_dim, num_remain);
|
||||||
|
|
||||||
|
const auto& dev = *EigenDeviceWrapper::GetInstance()->GetDevice();
|
||||||
|
// For numerical stability, logits should be shifted by maximum number along
|
||||||
|
// axis, calculate shifted_logits into softmax tensor for memory reuse.
|
||||||
|
if (num_remain == 1) {
|
||||||
|
// axis == -1, axis and class in same dimension, calculate along
|
||||||
|
// class dimension directly for higher performance
|
||||||
|
softmax.device(dev) = (logits -
|
||||||
|
logits.maximum(along_axis)
|
||||||
|
.eval()
|
||||||
|
.reshape(batch_by_one)
|
||||||
|
.broadcast(one_by_class))
|
||||||
|
.unaryExpr(ValueClip<T>());
|
||||||
|
} else {
|
||||||
|
// axis != -1, class dimension split into (axis, remain), max and sum
|
||||||
|
// should be calculated along axis dimension
|
||||||
|
softmax.device(dev) = (logits.reshape(batch_axis_remain) -
|
||||||
|
logits.reshape(batch_axis_remain)
|
||||||
|
.maximum(along_axis)
|
||||||
|
.eval()
|
||||||
|
.reshape(batch_one_remain)
|
||||||
|
.broadcast(one_axis_one)
|
||||||
|
.reshape(batch_axis_remain))
|
||||||
|
.reshape(batch_classes)
|
||||||
|
.unaryExpr(ValueClip<T>());
|
||||||
|
}
|
||||||
|
softmax.device(dev) = softmax.exp();
|
||||||
|
softmax.device(dev) = (softmax *
|
||||||
|
softmax.reshape(batch_axis_remain)
|
||||||
|
.sum(along_axis)
|
||||||
|
.inverse()
|
||||||
|
.eval()
|
||||||
|
.broadcast(one_axis));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void SoftmaxFunctor(const FDTensor& x, FDTensor* out, int axis) {
|
||||||
|
SoftmaxEigen<T>()(x, out, axis);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void SoftmaxKernel(const FDTensor& x, FDTensor* out, int axis) {
|
||||||
|
const int rank = x.shape.size();
|
||||||
|
const int calc_axis = CanonicalAxis(axis, rank);
|
||||||
|
int axis_dim = x.shape[calc_axis];
|
||||||
|
out->Allocate(x.shape, x.dtype);
|
||||||
|
if (out->Numel() == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const int n = SizeToAxis(calc_axis, x.shape);
|
||||||
|
const int d = SizeFromAxis(calc_axis, x.shape);
|
||||||
|
// Reshape to 2d tensor
|
||||||
|
|
||||||
|
FDTensor x_2d, out_2d;
|
||||||
|
x_2d.SetExternalData({n, d}, x.dtype, const_cast<void*>(x.Data()));
|
||||||
|
out_2d.SetExternalData({n, d}, out->dtype, out->Data());
|
||||||
|
|
||||||
|
SoftmaxFunctor<T>(x_2d, &out_2d, axis_dim);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Softmax(const FDTensor& x, FDTensor* out, int axis) {
|
||||||
|
FDASSERT(std::abs(axis) < x.shape.size(),
|
||||||
|
"The given axis should be smaller than the input's dimension");
|
||||||
|
FD_VISIT_FLOAT_TYPES(x.dtype, "SoftmaxKernel",
|
||||||
|
([&] { SoftmaxKernel<data_t>(x, out, axis); }));
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
} // namespace fastdeploy
|
30
csrc/fastdeploy/function/softmax.h
Normal file
30
csrc/fastdeploy/function/softmax.h
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "fastdeploy/core/fd_tensor.h"
|
||||||
|
|
||||||
|
namespace fastdeploy {
|
||||||
|
|
||||||
|
#ifdef ENABLE_FDTENSOR_FUNC
|
||||||
|
/** Excute the softmax operation for input FDTensor along given dims.
|
||||||
|
@param x The input tensor.
|
||||||
|
@param out The output tensor which stores the result.
|
||||||
|
@param axis The axis to be computed softmax value.
|
||||||
|
*/
|
||||||
|
FASTDEPLOY_DECL void Softmax(const FDTensor& x, FDTensor* out, int axis = -1);
|
||||||
|
|
||||||
|
#endif
|
||||||
|
} // namespace fastdeploy
|
52
csrc/fastdeploy/utils/axis_utils.h
Normal file
52
csrc/fastdeploy/utils/axis_utils.h
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
namespace fastdeploy {
|
||||||
|
|
||||||
|
static inline int CanonicalAxis(const int axis, const int rank) {
|
||||||
|
if (axis < 0) {
|
||||||
|
return axis + rank;
|
||||||
|
}
|
||||||
|
return axis;
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline int SizeToAxis(const int axis, const std::vector<int64_t>& dims) {
|
||||||
|
int size = 1;
|
||||||
|
for (int i = 0; i < axis; i++) {
|
||||||
|
size *= dims[i];
|
||||||
|
}
|
||||||
|
return size;
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline int SizeFromAxis(const int axis,
|
||||||
|
const std::vector<int64_t>& dims) {
|
||||||
|
int size = 1;
|
||||||
|
for (int i = axis; i < dims.size(); i++) {
|
||||||
|
size *= dims[i];
|
||||||
|
}
|
||||||
|
return size;
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline int SizeOutAxis(const int axis,
|
||||||
|
const std::vector<int64_t>& dims) {
|
||||||
|
int size = 1;
|
||||||
|
for (int i = axis + 1; i < dims.size(); i++) {
|
||||||
|
size *= dims[i];
|
||||||
|
}
|
||||||
|
return size;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace fastdeploy
|
@@ -103,36 +103,40 @@ FASTDEPLOY_DECL bool ReadBinaryFromFile(const std::string& file,
|
|||||||
#define FD_PRIVATE_CASE_TYPE(NAME, enum_type, type, ...) \
|
#define FD_PRIVATE_CASE_TYPE(NAME, enum_type, type, ...) \
|
||||||
FD_PRIVATE_CASE_TYPE_USING_HINT(NAME, enum_type, type, data_t, __VA_ARGS__)
|
FD_PRIVATE_CASE_TYPE_USING_HINT(NAME, enum_type, type, data_t, __VA_ARGS__)
|
||||||
|
|
||||||
#define FD_VISIT_ALL_TYPES(TYPE, NAME, ...) \
|
// Visit different data type to match the corresponding function of FDTensor
|
||||||
[&] { \
|
#define FD_VISIT_ALL_TYPES(TYPE, NAME, ...) \
|
||||||
const auto& __dtype__ = TYPE; \
|
[&] { \
|
||||||
switch (__dtype__) { \
|
const auto& __dtype__ = TYPE; \
|
||||||
FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::BOOL, bool, \
|
switch (__dtype__) { \
|
||||||
__VA_ARGS__) \
|
FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::BOOL, bool, \
|
||||||
FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::INT32, int32_t, \
|
__VA_ARGS__) \
|
||||||
__VA_ARGS__) \
|
FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::INT32, int32_t, \
|
||||||
FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::INT64, int64_t, \
|
__VA_ARGS__) \
|
||||||
__VA_ARGS__) \
|
FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::INT64, int64_t, \
|
||||||
FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::FP32, float, \
|
__VA_ARGS__) \
|
||||||
__VA_ARGS__) \
|
FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::FP32, float, \
|
||||||
FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::FP64, double, \
|
__VA_ARGS__) \
|
||||||
__VA_ARGS__) \
|
FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::FP64, double, \
|
||||||
default: \
|
__VA_ARGS__) \
|
||||||
FDASSERT(false, "Invalid enum data type.") \
|
default: \
|
||||||
} \
|
FDASSERT(false, \
|
||||||
|
"Invalid enum data type. Only accept data type BOOL, INT32, " \
|
||||||
|
"INT64, FP32, FP64.") \
|
||||||
|
} \
|
||||||
}()
|
}()
|
||||||
|
|
||||||
#define FD_VISIT_FLOAT_TYPES(TYPE, NAME, ...) \
|
#define FD_VISIT_FLOAT_TYPES(TYPE, NAME, ...) \
|
||||||
[&] { \
|
[&] { \
|
||||||
const auto& __dtype__ = TYPE; \
|
const auto& __dtype__ = TYPE; \
|
||||||
switch (__dtype__) { \
|
switch (__dtype__) { \
|
||||||
FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::FP32, float, \
|
FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::FP32, float, \
|
||||||
__VA_ARGS__) \
|
__VA_ARGS__) \
|
||||||
FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::FP64, double, \
|
FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::FP64, double, \
|
||||||
__VA_ARGS__) \
|
__VA_ARGS__) \
|
||||||
default: \
|
default: \
|
||||||
FDASSERT(false, "Invalid enum data type.") \
|
FDASSERT(false, \
|
||||||
} \
|
"Invalid enum data type. Only accept data type FP32, FP64.") \
|
||||||
|
} \
|
||||||
}()
|
}()
|
||||||
|
|
||||||
#define FD_VISIT_INT_TYPES(TYPE, NAME, ...) \
|
#define FD_VISIT_INT_TYPES(TYPE, NAME, ...) \
|
||||||
@@ -144,7 +148,9 @@ FASTDEPLOY_DECL bool ReadBinaryFromFile(const std::string& file,
|
|||||||
FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::INT64, int64_t, \
|
FD_PRIVATE_CASE_TYPE(NAME, ::fastdeploy::FDDataType::INT64, int64_t, \
|
||||||
__VA_ARGS__) \
|
__VA_ARGS__) \
|
||||||
default: \
|
default: \
|
||||||
FDASSERT(false, "Invalid enum data type.") \
|
FDASSERT( \
|
||||||
|
false, \
|
||||||
|
"Invalid enum data type. Only accept data type INT32, INT64.") \
|
||||||
} \
|
} \
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
@@ -1,6 +1,6 @@
|
|||||||
# FDTensor C++ 张量化函数
|
# FDTensor C++ 张量化函数
|
||||||
|
|
||||||
FDTensor是FastDeploy在C++层表示张量的结构体。该结构体主要用于管理推理部署时模型的输入输出数据,支持在不同的Runtime后端中使用。在基于C++的推理部署应用开发过程中,我们往往需要对输入输出的数据进行一些数据处理,用以得到模型的实际输入或者应用的实际输出。这种数据预处理的逻辑可以使用原生的C++标准库来实现,但开发难度会比较大,如对3维Tensor的第2维求最大值。针对这个问题,FastDeploy基于FDTensor开发了一套C++张量化函数,用于降低FastDeploy用户的开发成本,提高开发效率。目前主要分为三类函数:Reduce类函数,Manipulate类函数,Elementwise类函数。
|
FDTensor是FastDeploy在C++层表示张量的结构体。该结构体主要用于管理推理部署时模型的输入输出数据,支持在不同的Runtime后端中使用。在基于C++的推理部署应用开发过程中,我们往往需要对输入输出的数据进行一些数据处理,用以得到模型的实际输入或者应用的实际输出。这种数据预处理的逻辑可以使用原生的C++标准库来实现,但开发难度会比较大,如对3维Tensor的第2维求最大值。针对这个问题,FastDeploy基于FDTensor开发了一套C++张量化函数,用于降低FastDeploy用户的开发成本,提高开发效率。目前主要分为三类函数:Reduce类函数,Manipulate类函数,Math类函数以及Elementwise类函数。
|
||||||
|
|
||||||
## Reduce类函数
|
## Reduce类函数
|
||||||
|
|
||||||
@@ -239,6 +239,39 @@ input.SetExternalData({2, 3}, FDDataType::FP32, inputs.data());
|
|||||||
Transpose(input, &output, {1, 0});
|
Transpose(input, &output, {1, 0});
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Math类函数
|
||||||
|
|
||||||
|
目前FastDeploy支持1种Math类函数:Softmax。
|
||||||
|
|
||||||
|
### Softmax
|
||||||
|
|
||||||
|
#### 函数签名
|
||||||
|
|
||||||
|
```c++
|
||||||
|
/** Excute the softmax operation for input FDTensor along given dims.
|
||||||
|
@param x The input tensor.
|
||||||
|
@param out The output tensor which stores the result.
|
||||||
|
@param axis The axis to be computed softmax value.
|
||||||
|
*/
|
||||||
|
void Softmax(const FDTensor& x, FDTensor* out, int axis = -1);
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 使用示例
|
||||||
|
|
||||||
|
```c++
|
||||||
|
FDTensor input, output;
|
||||||
|
CheckShape check_shape;
|
||||||
|
CheckData check_data;
|
||||||
|
std::vector<float> inputs = {1, 2, 3, 4, 5, 6};
|
||||||
|
input.SetExternalData({2, 3}, FDDataType::FP32, inputs.data());
|
||||||
|
|
||||||
|
// Transpose the input tensor with axis {1, 0}.
|
||||||
|
// The output result would be
|
||||||
|
// [[0.04742587, 0.04742587, 0.04742587],
|
||||||
|
// [0.95257413, 0.95257413, 0.95257413]]
|
||||||
|
Softmax(input, &output, 0);
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
## Elementwise类函数
|
## Elementwise类函数
|
||||||
|
|
||||||
|
@@ -60,11 +60,12 @@ function(add_fastdeploy_unittest CC_FILE)
|
|||||||
endfunction()
|
endfunction()
|
||||||
|
|
||||||
if(WITH_TESTING)
|
if(WITH_TESTING)
|
||||||
|
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
|
||||||
add_library(fastdeploy_gtest_main STATIC gtest_main)
|
add_library(fastdeploy_gtest_main STATIC gtest_main)
|
||||||
target_link_libraries(fastdeploy_gtest_main PUBLIC gtest gflags)
|
target_link_libraries(fastdeploy_gtest_main PUBLIC gtest gflags)
|
||||||
message(STATUS "")
|
message(STATUS "")
|
||||||
message(STATUS "*************FastDeploy Unittest Summary**********")
|
message(STATUS "*************FastDeploy Unittest Summary**********")
|
||||||
file(GLOB ALL_TEST_SRCS ${PROJECT_SOURCE_DIR}/tests/test_*.cc)
|
file(GLOB_RECURSE ALL_TEST_SRCS ${PROJECT_SOURCE_DIR}/tests/*/test_*.cc)
|
||||||
foreach(_CC_FILE ${ALL_TEST_SRCS})
|
foreach(_CC_FILE ${ALL_TEST_SRCS})
|
||||||
add_fastdeploy_unittest(${_CC_FILE})
|
add_fastdeploy_unittest(${_CC_FILE})
|
||||||
endforeach()
|
endforeach()
|
||||||
|
48
tests/function/test_softmax.cc
Normal file
48
tests/function/test_softmax.cc
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include "fastdeploy/core/fd_tensor.h"
|
||||||
|
#include "fastdeploy/function/softmax.h"
|
||||||
|
#include "glog/logging.h"
|
||||||
|
#include "gtest/gtest.h"
|
||||||
|
#include "gtest_utils.h"
|
||||||
|
|
||||||
|
namespace fastdeploy {
|
||||||
|
#ifdef ENABLE_FDTENSOR_FUNC
|
||||||
|
TEST(fastdeploy, softmax) {
|
||||||
|
FDTensor input, output;
|
||||||
|
CheckShape check_shape;
|
||||||
|
CheckData check_data;
|
||||||
|
std::vector<float> inputs = {1, 2, 3, 4, 5, 6};
|
||||||
|
std::vector<float> expected_result_axis0 = {
|
||||||
|
0.04742587, 0.04742587, 0.04742587, 0.95257413, 0.95257413, 0.95257413};
|
||||||
|
std::vector<float> expected_result_axis1 = {
|
||||||
|
0.09003057, 0.24472846, 0.66524088, 0.09003057, 0.24472846, 0.66524088};
|
||||||
|
input.SetExternalData({2, 3}, FDDataType::FP32, inputs.data());
|
||||||
|
|
||||||
|
// axis = 0
|
||||||
|
Softmax(input, &output, 0);
|
||||||
|
check_shape(output.shape, {2, 3});
|
||||||
|
check_data(reinterpret_cast<const float*>(output.Data()),
|
||||||
|
expected_result_axis0.data(), expected_result_axis0.size());
|
||||||
|
|
||||||
|
// axis = 1
|
||||||
|
Softmax(input, &output, 1);
|
||||||
|
check_shape(output.shape, {2, 3});
|
||||||
|
check_data(reinterpret_cast<const float*>(output.Data()),
|
||||||
|
expected_result_axis1.data(), expected_result_axis1.size());
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
} // namespace fastdeploy
|
Reference in New Issue
Block a user