mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-08 18:11:00 +08:00
Add Transpose function (#91)
* Add Transpose function * csrcs->csrc * Add transpose unittest * Add reduce_max_large_dim unittest
This commit is contained in:
@@ -120,7 +120,7 @@ void FDTensor::PrintInfo(const std::string& prefix) {
|
||||
} else {
|
||||
FDASSERT(false,
|
||||
"PrintInfo function doesn't support current situation, maybe you "
|
||||
"need enhance this function now.")
|
||||
"need enhance this function now.");
|
||||
}
|
||||
std::cout << prefix << ": shape=";
|
||||
for (int i = 0; i < shape.size(); ++i) {
|
||||
|
@@ -12,11 +12,13 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "fastdeploy/function/reduce.h"
|
||||
|
||||
#include <set>
|
||||
|
||||
#include "fastdeploy/function/eigen.h"
|
||||
#include "fastdeploy/function/reduce.h"
|
||||
#include "fastdeploy/function/reduce_functor.h"
|
||||
#include "fastdeploy/function/transpose.h"
|
||||
#include "fastdeploy/utils/utils.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
@@ -71,7 +73,7 @@ void ReduceFunctor(const FDTensor& input, FDTensor* output,
|
||||
inline void GetShuffledDim(const std::vector<int64_t>& src_dims,
|
||||
std::vector<int64_t>* dst_dims,
|
||||
const std::vector<int64_t>& reduced_dims,
|
||||
std::vector<int>* perm_axis) {
|
||||
std::vector<int64_t>* perm_axis) {
|
||||
// check if it's a reduced dim
|
||||
std::vector<bool> src_dims_check(src_dims.size(), false);
|
||||
size_t src_size = src_dims.size();
|
||||
@@ -104,19 +106,33 @@ template <typename OutT>
|
||||
void GetShuffledInput(const FDTensor& input, FDTensor* shuffled_input,
|
||||
const std::vector<int64_t>& dims) {
|
||||
auto shuffled_dims = input.shape;
|
||||
std::vector<int> perm_axis(input.shape.size());
|
||||
std::vector<int64_t> perm_axis(input.shape.size());
|
||||
GetShuffledDim(input.shape, &shuffled_dims, dims, &perm_axis);
|
||||
|
||||
shuffled_input->Allocate(shuffled_dims, input.dtype);
|
||||
// TODO(zhoushunjie) : Need to implement trans function
|
||||
// phi::funcs::TransposeNormal<DeviceContext, OutT> trans;
|
||||
// trans(dev_ctx, input, shuffled_input, perm_axis);
|
||||
Transpose(input, shuffled_input, perm_axis);
|
||||
}
|
||||
|
||||
//////////////// HandleLargeDim
|
||||
template <typename OutT, typename Functor>
|
||||
void HandleLargeDim(const FDTensor& input, FDTensor* output,
|
||||
const std::vector<int64_t>& dims, bool keep_dim) {
|
||||
auto out_dims = input.shape;
|
||||
std::vector<int64_t> dims_ref = dims;
|
||||
auto x_rank = input.shape.size();
|
||||
for (size_t i = 0; i < dims_ref.size(); ++i) {
|
||||
if (dims_ref[i] < 0) dims_ref[i] = x_rank + dims_ref[i];
|
||||
out_dims[dims_ref[i]] = 1;
|
||||
}
|
||||
if (!keep_dim) {
|
||||
const int kDelFlag = -2;
|
||||
for (size_t i = 0; i < dims_ref.size(); ++i) {
|
||||
out_dims[dims_ref[i]] = kDelFlag;
|
||||
}
|
||||
out_dims.erase(remove(out_dims.begin(), out_dims.end(), kDelFlag),
|
||||
out_dims.end());
|
||||
}
|
||||
output->Allocate(out_dims, TypeToDataType<OutT>::dtype);
|
||||
// shuffle the reduced dim to the end
|
||||
FDTensor shuffled_input;
|
||||
GetShuffledInput<OutT>(input, &shuffled_input, dims);
|
||||
@@ -126,11 +142,9 @@ void HandleLargeDim(const FDTensor& input, FDTensor* output,
|
||||
const int64_t reduced = shuffled_input.Numel() / unreduced;
|
||||
shuffled_input.Allocate({unreduced, reduced}, TypeToDataType<OutT>::dtype);
|
||||
|
||||
auto output_dim = output->shape;
|
||||
output->Allocate({unreduced}, TypeToDataType<OutT>::dtype);
|
||||
|
||||
output->shape = {unreduced};
|
||||
ReduceFunctor<OutT, 2, 1, Functor>(shuffled_input, output, {1}, keep_dim);
|
||||
output->shape = output_dim;
|
||||
output->shape = out_dims;
|
||||
}
|
||||
|
||||
////////////// ReduceKernel
|
||||
@@ -152,7 +166,7 @@ void ReduceKernelImpl(const FDTensor& input, FDTensor* output,
|
||||
} else {
|
||||
int ndim = input.shape.size();
|
||||
int rdim = dims.size();
|
||||
if (ndim > 3) {
|
||||
if (ndim > 4) {
|
||||
HandleLargeDim<OutT, Functor>(input, output, dims, keep_dim);
|
||||
} else {
|
||||
HANDLE_REDUCE_DIM(4, 3);
|
||||
|
115
csrc/fastdeploy/function/transpose.cc
Normal file
115
csrc/fastdeploy/function/transpose.cc
Normal file
@@ -0,0 +1,115 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "fastdeploy/function/transpose.h"
|
||||
#include "fastdeploy/function/eigen.h"
|
||||
#include "fastdeploy/utils/utils.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
#ifdef ENABLE_FDTENSOR_FUNC
|
||||
|
||||
template <typename T>
|
||||
struct TransposeNormalKernel {
|
||||
void operator()(const FDTensor& in, FDTensor* out,
|
||||
const std::vector<int64_t>& axis) {
|
||||
const int rank = axis.size();
|
||||
auto in_stride = GetStride(in.shape);
|
||||
auto out_stride = GetStride(out->shape);
|
||||
const T* in_ptr = reinterpret_cast<const T*>(in.Data());
|
||||
T* out_ptr = reinterpret_cast<T*>(out->Data());
|
||||
|
||||
auto transpose_helper = [&](int64_t beg, int64_t end) {
|
||||
for (int64_t out_idx = beg; out_idx < end; ++out_idx) {
|
||||
int64_t in_idx = 0;
|
||||
int64_t tmp_idx = out_idx;
|
||||
// calculate the input index
|
||||
for (int i = 0; i < rank; ++i) {
|
||||
const int64_t coordinate = tmp_idx / out_stride[i];
|
||||
tmp_idx -= coordinate * out_stride[i];
|
||||
in_idx += coordinate * in_stride[axis[i]];
|
||||
}
|
||||
out_ptr[out_idx] = in_ptr[in_idx];
|
||||
}
|
||||
};
|
||||
transpose_helper(0, out->Numel());
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, int Rank>
|
||||
struct TransposeKernelImpl {
|
||||
void operator()(const FDTensor& in, FDTensor* out,
|
||||
const std::vector<int64_t>& axis) {
|
||||
Eigen::array<int, Rank> permute;
|
||||
for (int i = 0; i < Rank; i++) {
|
||||
permute[i] = axis[i];
|
||||
}
|
||||
|
||||
auto& place = *EigenDeviceWrapper::GetInstance()->GetDevice();
|
||||
auto eigen_in = EigenTensor<T, Rank>::From(in);
|
||||
auto eigen_out = EigenTensor<T, Rank>::From(*out);
|
||||
eigen_out.device(place) = eigen_in.shuffle(permute);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
void TransposeKernel(const FDTensor& x, FDTensor* out,
|
||||
const std::vector<int64_t>& axis) {
|
||||
int rank = axis.size();
|
||||
switch (rank) {
|
||||
case 1:
|
||||
TransposeKernelImpl<T, 1> trans1;
|
||||
trans1(x, out, axis);
|
||||
break;
|
||||
case 2:
|
||||
TransposeKernelImpl<T, 2> trans2;
|
||||
trans2(x, out, axis);
|
||||
break;
|
||||
case 3:
|
||||
TransposeKernelImpl<T, 3> trans3;
|
||||
trans3(x, out, axis);
|
||||
break;
|
||||
case 4:
|
||||
TransposeKernelImpl<T, 4> trans4;
|
||||
trans4(x, out, axis);
|
||||
break;
|
||||
default:
|
||||
// for rank >= 4 situation
|
||||
TransposeNormalKernel<T> trans_normal;
|
||||
trans_normal(x, out, axis);
|
||||
}
|
||||
}
|
||||
|
||||
void Transpose(const FDTensor& x, FDTensor* out,
|
||||
const std::vector<int64_t>& dims) {
|
||||
size_t dims_size = dims.size();
|
||||
FDASSERT(dims_size == x.shape.size(),
|
||||
"The input tensor's dimension should be equal to the dims's size.");
|
||||
std::vector<int> count(dims_size, 0);
|
||||
for (size_t i = 0; i < dims_size; i++) {
|
||||
FDASSERT(dims[i] >= 0, "The dims should be greater than or equal to 0.");
|
||||
FDASSERT(dims[i] < static_cast<int>(dims_size) && ++count[dims[i]] == 1,
|
||||
"Each element of Attribute axis should be a unique value range "
|
||||
"from 0 to (dims - 1), where the dims is the axis's size, unique "
|
||||
"value means this axis value can appear only once. ");
|
||||
}
|
||||
std::vector<int64_t> out_dims(dims_size);
|
||||
for (size_t i = 0; i < dims_size; i++) {
|
||||
out_dims[i] = x.shape[dims[i]];
|
||||
}
|
||||
out->Allocate(out_dims, x.dtype);
|
||||
FD_VISIT_ALL_TYPES(x.dtype, "TransposeKernel",
|
||||
([&] { TransposeKernel<data_t>(x, out, dims); }));
|
||||
}
|
||||
#endif
|
||||
} // namespace fastdeploy
|
29
csrc/fastdeploy/function/transpose.h
Normal file
29
csrc/fastdeploy/function/transpose.h
Normal file
@@ -0,0 +1,29 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "fastdeploy/core/fd_tensor.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
#ifdef ENABLE_FDTENSOR_FUNC
|
||||
/** Excute the transpose operation for input FDTensor along given dims.
|
||||
@param x The input tensor.
|
||||
@param out The output tensor which stores the result.
|
||||
@param dims The vector of axis which the input tensor will transpose.
|
||||
*/
|
||||
FASTDEPLOY_DECL void Transpose(const FDTensor& x, FDTensor* out,
|
||||
const std::vector<int64_t>& dims);
|
||||
#endif
|
||||
} // namespace fastdeploy
|
@@ -46,4 +46,13 @@ bool ReadBinaryFromFile(const std::string& file, std::string* contents) {
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<int64_t> GetStride(const std::vector<int64_t>& dims) {
|
||||
auto dims_size = dims.size();
|
||||
std::vector<int64_t> result(dims_size, 1);
|
||||
for (int i = dims_size - 2; i >= 0; --i) {
|
||||
result[i] = result[i + 1] * dims[i + 1];
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace fastdeploy
|
||||
|
@@ -20,6 +20,7 @@
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#if defined(_WIN32)
|
||||
#ifdef FASTDEPLOY_LIB
|
||||
@@ -147,4 +148,7 @@ FASTDEPLOY_DECL bool ReadBinaryFromFile(const std::string& file,
|
||||
} \
|
||||
}()
|
||||
|
||||
FASTDEPLOY_DECL std::vector<int64_t> GetStride(
|
||||
const std::vector<int64_t>& dims);
|
||||
|
||||
} // namespace fastdeploy
|
||||
|
@@ -1,6 +1,6 @@
|
||||
# FDTensor C++ 张量化函数
|
||||
|
||||
FDTensor是FastDeploy在C++层表示张量的结构体。该结构体主要用于管理推理部署时模型的输入输出数据,支持在不同的Runtime后端中使用。在基于C++的推理部署应用开发过程中,我们往往需要对输入输出的数据进行一些数据处理,用以得到模型的实际输入或者应用的实际输出。这种数据预处理的逻辑可以使用原生的C++标准库来实现,但开发难度会比较大,如对3维Tensor的第2维求最大值。针对这个问题,FastDeploy基于FDTensor开发了一套C++张量化函数,用于降低FastDeploy用户的开发成本,提高开发效率。目前主要分为两类函数:Reduce类函数和Elementwise类函数。
|
||||
FDTensor是FastDeploy在C++层表示张量的结构体。该结构体主要用于管理推理部署时模型的输入输出数据,支持在不同的Runtime后端中使用。在基于C++的推理部署应用开发过程中,我们往往需要对输入输出的数据进行一些数据处理,用以得到模型的实际输入或者应用的实际输出。这种数据预处理的逻辑可以使用原生的C++标准库来实现,但开发难度会比较大,如对3维Tensor的第2维求最大值。针对这个问题,FastDeploy基于FDTensor开发了一套C++张量化函数,用于降低FastDeploy用户的开发成本,提高开发效率。目前主要分为三类函数:Reduce类函数,Manipulate类函数,Elementwise类函数。
|
||||
|
||||
## Reduce类函数
|
||||
|
||||
@@ -209,6 +209,37 @@ input.SetExternalData({2, 3}, FDDataType::INT32, bool_inputs.data());
|
||||
All(input, &output, {0}, /* keep_dim = */true);
|
||||
```
|
||||
|
||||
## Manipulate类函数
|
||||
|
||||
目前FastDeploy支持1种Manipulate类函数:Transpose。
|
||||
|
||||
### Transpose
|
||||
|
||||
#### 函数签名
|
||||
|
||||
```c++
|
||||
/** Excute the transpose operation for input FDTensor along given dims.
|
||||
@param x The input tensor.
|
||||
@param out The output tensor which stores the result.
|
||||
@param dims The vector of axis which the input tensor will transpose.
|
||||
*/
|
||||
void Transpose(const FDTensor& x, FDTensor* out,
|
||||
const std::vector<int64_t>& dims);
|
||||
```
|
||||
|
||||
#### 使用示例
|
||||
|
||||
```c++
|
||||
FDTensor input, output;
|
||||
std::vector<float> inputs = {2, 4, 3, 7, 1, 5};
|
||||
input.SetExternalData({2, 3}, FDDataType::FP32, inputs.data());
|
||||
|
||||
// Transpose the input tensor with axis {1, 0}.
|
||||
// The output result would be [[2, 7], [4, 1], [3, 5]]
|
||||
Transpose(input, &output, {1, 0});
|
||||
```
|
||||
|
||||
|
||||
## Elementwise类函数
|
||||
|
||||
正在开发中,敬请关注······
|
||||
|
@@ -59,6 +59,28 @@ TEST(fastdeploy, reduce_max) {
|
||||
expected_result_noaxis.data(), expected_result_noaxis.size());
|
||||
}
|
||||
|
||||
TEST(fastdeploy, reduce_max_large_dim) {
|
||||
FDTensor input, output;
|
||||
CheckShape check_shape;
|
||||
CheckData check_data;
|
||||
|
||||
std::vector<int> inputs = {2, 4, 3, 7, 1, 5, 6, 9};
|
||||
std::vector<int> expected_result_axis0 = {4, 7, 5, 9};
|
||||
input.SetExternalData({2, 1, 2, 1, 2}, FDDataType::INT32, inputs.data());
|
||||
|
||||
// keep_dim = true, reduce_all = false
|
||||
Max(input, &output, {4}, true);
|
||||
check_shape(output.shape, {2, 1, 2, 1, 1});
|
||||
check_data(reinterpret_cast<const int*>(output.Data()),
|
||||
expected_result_axis0.data(), expected_result_axis0.size());
|
||||
|
||||
// keep_dim = false, reduce_all = false
|
||||
Max(input, &output, {4});
|
||||
check_shape(output.shape, {2, 1, 2, 1});
|
||||
check_data(reinterpret_cast<const int*>(output.Data()),
|
||||
expected_result_axis0.data(), expected_result_axis0.size());
|
||||
}
|
||||
|
||||
TEST(fastdeploy, reduce_min) {
|
||||
FDTensor input, output;
|
||||
CheckShape check_shape;
|
||||
|
61
tests/test_transpose.cc
Normal file
61
tests/test_transpose.cc
Normal file
@@ -0,0 +1,61 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <numeric>
|
||||
#include <vector>
|
||||
#include "fastdeploy/core/fd_tensor.h"
|
||||
#include "fastdeploy/function/transpose.h"
|
||||
|
||||
#include "glog/logging.h"
|
||||
#include "gtest/gtest.h"
|
||||
#include "gtest_utils.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
#ifdef ENABLE_FDTENSOR_FUNC
|
||||
TEST(fastdeploy, transpose_2d) {
|
||||
FDTensor input, output;
|
||||
CheckShape check_shape;
|
||||
CheckData check_data;
|
||||
|
||||
std::vector<float> inputs = {2, 4, 3, 7, 1, 5};
|
||||
std::vector<float> expected_result = {2, 7, 4, 1, 3, 5};
|
||||
input.SetExternalData({2, 3}, FDDataType::FP32, inputs.data());
|
||||
|
||||
Transpose(input, &output, {1, 0});
|
||||
check_shape(output.shape, {3, 2});
|
||||
check_data(reinterpret_cast<const float*>(output.Data()),
|
||||
expected_result.data(), expected_result.size());
|
||||
}
|
||||
|
||||
TEST(fastdeploy, transpose_5d) {
|
||||
FDTensor input, output;
|
||||
CheckShape check_shape;
|
||||
CheckData check_data;
|
||||
|
||||
std::vector<int64_t> input_shape = {2, 1, 3, 1, 2};
|
||||
auto total_size = std::accumulate(input_shape.begin(), input_shape.end(), 1,
|
||||
std::multiplies<int64_t>());
|
||||
std::vector<int> inputs(total_size, 1);
|
||||
std::iota(inputs.begin(), inputs.end(), 1);
|
||||
std::vector<int> expected_result = {1, 3, 5, 2, 4, 6, 7, 9, 11, 8, 10, 12};
|
||||
input.SetExternalData(input_shape, FDDataType::INT32, inputs.data());
|
||||
|
||||
Transpose(input, &output, {0, 1, 4, 3, 2});
|
||||
check_shape(output.shape, {2, 1, 2, 1, 3});
|
||||
check_data(reinterpret_cast<const int*>(output.Data()),
|
||||
expected_result.data(), expected_result.size());
|
||||
}
|
||||
|
||||
#endif
|
||||
} // namespace fastdeploy
|
Reference in New Issue
Block a user