[Backend] Add AdaptivePool2d for TensorRT plugin (#668)

* add adaptivepool2d for tensorrt plugin

* update code

* update code

* update code to fix bug
This commit is contained in:
yeliang2258
2022-11-25 17:36:59 +08:00
committed by GitHub
parent cc74fb800d
commit d14828cb18
9 changed files with 461 additions and 31 deletions

View File

@@ -164,7 +164,7 @@ configure_file(${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/pybind/main.cc
file(GLOB_RECURSE ALL_DEPLOY_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/*.cc)
file(GLOB_RECURSE FDTENSOR_FUNC_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/function/*.cc)
file(GLOB_RECURSE FDTENSOR_FUNC_CUDA_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/function/*.cu)
file(GLOB_RECURSE DEPLOY_ORT_CUDA_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/backends/ort/*.cu)
file(GLOB_RECURSE DEPLOY_OP_CUDA_KERNEL_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/backends/op_cuda_kernels/*.cu)
file(GLOB_RECURSE DEPLOY_ORT_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/backends/ort/*.cc)
file(GLOB_RECURSE DEPLOY_PADDLE_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/backends/paddle/*.cc)
file(GLOB_RECURSE DEPLOY_POROS_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/backends/poros/*.cc)
@@ -202,7 +202,7 @@ if(ENABLE_ORT_BACKEND)
include(${PROJECT_SOURCE_DIR}/cmake/onnxruntime.cmake)
list(APPEND DEPEND_LIBS external_onnxruntime)
if(WITH_GPU)
list(APPEND ALL_DEPLOY_SRCS ${DEPLOY_ORT_CUDA_SRCS})
list(APPEND ALL_DEPLOY_SRCS ${DEPLOY_OP_CUDA_KERNEL_SRCS})
endif()
endif()
@@ -361,6 +361,7 @@ if(ENABLE_TRT_BACKEND)
find_library(TRT_ONNX_LIB nvonnxparser ${TRT_LIB_DIR} NO_DEFAULT_PATH)
find_library(TRT_PLUGIN_LIB nvinfer_plugin ${TRT_LIB_DIR} NO_DEFAULT_PATH)
list(APPEND DEPEND_LIBS ${TRT_INFER_LIB} ${TRT_ONNX_LIB} ${TRT_PLUGIN_LIB})
list(APPEND ALL_DEPLOY_SRCS ${DEPLOY_OP_CUDA_KERNEL_SRCS})
if(NOT BUILD_ON_JETSON)
if(NOT EXISTS "${CMAKE_CURRENT_BINARY_DIR}/third_libs/install/tensorrt")

View File

@@ -1,15 +1,12 @@
#include "adaptive_pool2d.h"
#include <cuda.h>
#include <cuda_runtime.h>
#include <cstdint>
#include <iostream>
#include <vector>
#include <math.h>
#include "adaptive_pool2d_kernel.h"
namespace fastdeploy {
__global__ void CudaCastKernel(const float* in, float* out, int edge, int out_bc_offset, int in_bc_offset, int ih, int iw, int oh, int ow, bool is_avg) {
int position = blockDim.x * blockIdx.x + threadIdx.x;
if (position >= edge) return;
if (position >= edge) {
return;
}
int offset = floorf(float(position) / out_bc_offset);
int h = floorf(float(position % out_bc_offset) / ow);
int w = (position % out_bc_offset) % ow;
@@ -17,17 +14,17 @@ __global__ void CudaCastKernel(const float* in, float* out, int edge, int out_b
int hend = ceilf(static_cast<float>((h + 1) * ih) / oh);
int wstart = floorf(static_cast<float>(w * iw) / ow);
int wend = ceilf(static_cast<float>((w + 1) * iw) / ow);
if(is_avg){
if(is_avg) {
out[position] = 0.0;
}else{
} else {
out[position] = in[offset * in_bc_offset + hstart * iw + wstart];
}
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
int input_idx = h * iw + w;
if(is_avg){
if(is_avg) {
out[position] = out[position] + in[offset * in_bc_offset + input_idx];
}else{
} else {
out[position] = max(out[position], in[offset * in_bc_offset + input_idx]);
}
}
@@ -40,7 +37,7 @@ void CudaAdaptivePool(const std::vector<int64_t>& input_dims, const std::vector<
int out_bc_offset = output_dims[2] * output_dims[3];
int in_bc_offset = input_dims[2] * input_dims[3];
int jobs = 1;
for(int i : output_dims){
for(int i : output_dims) {
jobs *= i;
}
bool is_avg = pooling_type == "avg";

View File

@@ -0,0 +1,35 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cuda.h>
#include <cuda_runtime.h>
#include <cstdint>
#include <iostream>
#include <vector>
#include <math.h>
namespace fastdeploy {
void CudaAdaptivePool(const std::vector<int64_t>& input_dims,
const std::vector<int64_t>& output_dims,
float* output,
const float* input,
void* compute_stream,
const std::string& pooling_type);
} // namespace fastdeploy

View File

@@ -14,14 +14,9 @@
#ifndef NON_64_PLATFORM
#include "fastdeploy/backends/ort/ops/adaptive_pool2d.h"
#include <algorithm>
#include <cmath>
#include "fastdeploy/core/fd_tensor.h"
#include "fastdeploy/utils/utils.h"
#include "adaptive_pool2d.h"
namespace fastdeploy {
struct OrtTensorDimensions : std::vector<int64_t> {
OrtTensorDimensions(Ort::CustomOpApi ort, const OrtValue* value) {
OrtTensorTypeAndShapeInfo* info = ort.GetTensorTypeAndShape(value);

View File

@@ -16,19 +16,19 @@
#include <map>
#include <string>
#include <algorithm>
#include <cmath>
#include "fastdeploy/core/fd_tensor.h"
#include "fastdeploy/utils/utils.h"
#ifndef NON_64_PLATFORM
#include "onnxruntime_cxx_api.h" // NOLINT
namespace fastdeploy {
#ifdef WITH_GPU
void CudaAdaptivePool(const std::vector<int64_t>& input_dims,
const std::vector<int64_t>& output_dims,
float* output,
const float* input,
void* compute_stream,
const std::string& pooling_type);
#include "fastdeploy/backends/op_cuda_kernels/adaptive_pool2d_kernel.h"
#endif
namespace fastdeploy {
struct AdaptivePool2dKernel {
protected:
std::string pooling_type_ = "avg";

View File

@@ -0,0 +1,206 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "adaptive_pool2d.h"
namespace fastdeploy {
nvinfer1::PluginFieldCollection AdaptivePool2dPluginCreator::mFC{};
std::vector<nvinfer1::PluginField> AdaptivePool2dPluginCreator::mPluginAttributes;
pluginStatus_t AdaptivePool2dInference(cudaStream_t stream, int32_t n, const void* input, void* output);
AdaptivePool2d::AdaptivePool2d(std::vector<int32_t> output_size, std::string pooling_type) {
output_size_ = output_size;
pooling_type_ = pooling_type;
}
AdaptivePool2d::AdaptivePool2d(const void* buffer, size_t length) {
const char *d = reinterpret_cast<const char*>(buffer), *a = d;
output_size_.resize(4);
for(int64_t i =0 ; i < 4; i++){
output_size_[i] =read<int32_t>(d);
}
if(read<int32_t>(d) == 0){
pooling_type_ = "avg";
}else{
pooling_type_ = "max";
}
FDASSERT(d == a + length, "deserialize failed.");
}
int AdaptivePool2d::getNbOutputs() const noexcept {
return 1;
}
nvinfer1::DimsExprs AdaptivePool2d::getOutputDimensions(
int outputIndex, const nvinfer1::DimsExprs* inputs,
int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept {
try {
nvinfer1::DimsExprs output(inputs[0]);
output.d[2] = exprBuilder.constant(static_cast<int32_t>(output_size_[2]));
output.d[3] = exprBuilder.constant(static_cast<int32_t>(output_size_[3]));
return output;
}
catch (const std::exception& e) {
FDASSERT(false, "getOutputDimensions failed: %s.",e.what());
}
return nvinfer1::DimsExprs{};
}
int AdaptivePool2d::enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
const nvinfer1::PluginTensorDesc* outputDesc,
const void* const* inputs,
void* const* outputs,
void* workspace,
cudaStream_t stream) noexcept {
if (inputDesc[0].type != nvinfer1::DataType::kFLOAT) {
return -1;
}
auto const* data = static_cast<float const*>(inputs[0]);
auto* result = static_cast<float*>(outputs[0]);
int nums = outputDesc[0].dims.d[0] * outputDesc[0].dims.d[1] * outputDesc[0].dims.d[2]* outputDesc[0].dims.d[3];
std::vector<int64_t> input_size, output_size;
for(int i =0; i< 4; i++){
input_size.push_back(inputDesc[0].dims.d[i]);
output_size.push_back(outputDesc[0].dims.d[i]);
}
CudaAdaptivePool(input_size, output_size, result, data, stream, pooling_type_);
return cudaPeekAtLastError();
}
size_t AdaptivePool2d::getSerializationSize() const noexcept {
return 5 * sizeof(int32_t) ;
}
void AdaptivePool2d::serialize(void* buffer) const noexcept {
char *d = reinterpret_cast<char*>(buffer), *a = d;
for(int64_t i=0; i< 4; i++){
write(d, output_size_[i]);
}
int32_t pooling_type_val = 0;
if(pooling_type_ != "avg"){
pooling_type_val = 1;
}
write(d, pooling_type_val);
FDASSERT(d == a + getSerializationSize(), "d == a + getSerializationSize()");
}
nvinfer1::DataType AdaptivePool2d::getOutputDataType(
int index, const nvinfer1::DataType* inputType, int nbInputs) const noexcept {
return inputType[0];
}
bool AdaptivePool2d::supportsFormatCombination(
int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept {
return (inOut[pos].format == nvinfer1::PluginFormat::kLINEAR);
}
int AdaptivePool2d::initialize() noexcept {
return 0;
}
void AdaptivePool2d::terminate() noexcept {
return;
}
size_t AdaptivePool2d::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
int nbInputs,
const nvinfer1::PluginTensorDesc* outputs,
int nbOutputs) const noexcept {
return 0;
}
const char* AdaptivePool2d::getPluginType() const noexcept {
return "AdaptivePool2d";
}
const char* AdaptivePool2d::getPluginVersion() const noexcept {
return "1";
}
void AdaptivePool2d::destroy() noexcept {
return;
}
void AdaptivePool2d::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs,
const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept {
return;
}
nvinfer1::IPluginV2DynamicExt* AdaptivePool2d::clone() const noexcept {
try{
nvinfer1::IPluginV2DynamicExt* plugin = new AdaptivePool2d(output_size_, pooling_type_);
plugin->setPluginNamespace(mNamespace.c_str());
return plugin;
}
catch (std::exception const& e){
FDASSERT(false, "clone failed: %s.",e.what());
}
return nullptr;
}
AdaptivePool2dPluginCreator::AdaptivePool2dPluginCreator() {
mPluginAttributes.clear();
mPluginAttributes.emplace_back(nvinfer1::PluginField("output_size", nullptr, nvinfer1::PluginFieldType::kINT32, 4));
mPluginAttributes.emplace_back(nvinfer1::PluginField("pooling_type", nullptr, nvinfer1::PluginFieldType::kCHAR, 3));
mFC.nbFields = mPluginAttributes.size();
mFC.fields = mPluginAttributes.data();
}
const char* AdaptivePool2dPluginCreator::getPluginName() const noexcept {
return "AdaptivePool2d";
}
const char* AdaptivePool2dPluginCreator::getPluginVersion() const noexcept {
return "1";
}
const nvinfer1::PluginFieldCollection* AdaptivePool2dPluginCreator::getFieldNames() noexcept {
return &mFC;
}
nvinfer1::IPluginV2DynamicExt* AdaptivePool2dPluginCreator::createPlugin(const char* name,
const nvinfer1::PluginFieldCollection* fc) noexcept {
try{
const nvinfer1::PluginField* fields = fc->fields;
auto const dims = static_cast<int32_t const*>(fields[0].data);
output_size_.resize(4);
for(int64_t i = 0; i < 4; i++){
output_size_[i] = dims[i];
}
const char* pooling_type_ptr = (static_cast<char const*>(fields[1].data));
std::string pooling_type(pooling_type_ptr, 3);
pooling_type_ = pooling_type;
return new AdaptivePool2d(output_size_, pooling_type_);
}
catch (std::exception const& e){
FDASSERT(false, "createPlugin failed: %s.",e.what());
}
return nullptr;
}
nvinfer1::IPluginV2DynamicExt* AdaptivePool2dPluginCreator::deserializePlugin(const char* name,
const void* serialData,
size_t serialLength) noexcept {
try{
return new AdaptivePool2d(serialData, serialLength);
}
catch (std::exception const& e){
FDASSERT(false, "deserializePlugin failed: %s.",e.what());
}
return nullptr;
}
} // namespace fastdeploy

View File

@@ -0,0 +1,112 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/backends/op_cuda_kernels/adaptive_pool2d_kernel.h"
#include "common.h" // NOLINT
namespace fastdeploy {
class AdaptivePool2d : public BasePlugin {
public:
AdaptivePool2d(std::vector<int32_t> output_size, std::string pooling_type);
AdaptivePool2d(const void* buffer, size_t length);
~AdaptivePool2d() override = default;
int getNbOutputs() const noexcept override;
nvinfer1::DimsExprs getOutputDimensions(
int outputIndex,
const nvinfer1::DimsExprs* inputs,
int nbInputs,
nvinfer1::IExprBuilder& exprBuilder) noexcept override;
nvinfer1::DataType getOutputDataType(
int index,
const nvinfer1::DataType* inputType,
int nbInputs) const noexcept override;
bool supportsFormatCombination(
int pos,
const nvinfer1::PluginTensorDesc* inOut,
int nbInputs,
int nbOutputs) noexcept override;
int initialize() noexcept override;
void terminate() noexcept override;
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
int nbInputs,
const nvinfer1::PluginTensorDesc* outputs,
int nbOutputs) const noexcept override;
int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
const nvinfer1::PluginTensorDesc* outputDesc,
const void* const* inputs,
void* const* outputs,
void* workspace,
cudaStream_t stream) noexcept override;
size_t getSerializationSize() const noexcept override;
void serialize(void* buffer) const noexcept override;
const char* getPluginType() const noexcept override;
const char* getPluginVersion() const noexcept override;
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
int nbInputs,
const nvinfer1::DynamicPluginTensorDesc* out,
int nbOutputs) noexcept override;
void destroy() noexcept override;
nvinfer1::IPluginV2DynamicExt* clone() const noexcept override;
private:
std::vector<int32_t> output_size_;
std::string pooling_type_;
};
class AdaptivePool2dPluginCreator : public BaseCreator {
public:
AdaptivePool2dPluginCreator();
~AdaptivePool2dPluginCreator() override = default;
const char* getPluginName() const noexcept override;
const char* getPluginVersion() const noexcept override;
const nvinfer1::PluginFieldCollection* getFieldNames() noexcept override;
nvinfer1::IPluginV2DynamicExt* createPlugin(const char* name,
const nvinfer1::PluginFieldCollection* fc) noexcept override;
nvinfer1::IPluginV2DynamicExt* deserializePlugin(const char* name,
const void* serialData,
size_t serialLength) noexcept override;
private:
static nvinfer1::PluginFieldCollection mFC;
static std::vector<nvinfer1::PluginField> mPluginAttributes;
std::vector<int32_t> output_size_;
std::string pooling_type_;
};
REGISTER_TENSORRT_PLUGIN(AdaptivePool2dPluginCreator);
} // namespace fastdeploy

View File

@@ -0,0 +1,80 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "NvInferPlugin.h"
#include "NvInferRuntimeCommon.h"
#include "fastdeploy/utils/utils.h"
#include <iostream>
#include <string>
#include <vector>
#include <memory>
#include <cstring>
#include <sstream>
namespace fastdeploy {
class BasePlugin : public nvinfer1::IPluginV2DynamicExt {
protected:
void setPluginNamespace(const char* libNamespace) noexcept override {
mNamespace = libNamespace;
}
const char* getPluginNamespace() const noexcept override {
return mNamespace.c_str();
}
std::string mNamespace;
};
class BaseCreator : public nvinfer1::IPluginCreator {
public:
void setPluginNamespace(const char* libNamespace) noexcept override {
mNamespace = libNamespace;
}
const char* getPluginNamespace() const noexcept override {
return mNamespace.c_str();
}
protected:
std::string mNamespace;
};
typedef enum {
STATUS_SUCCESS = 0,
STATUS_FAILURE = 1,
STATUS_BAD_PARAM = 2,
STATUS_NOT_SUPPORTED = 3,
STATUS_NOT_INITIALIZED = 4
} pluginStatus_t;
// Write values into buffer
template <typename T>
void write(char*& buffer, const T& val) {
std::memcpy(buffer, &val, sizeof(T));
buffer += sizeof(T);
}
// Read values from buffer
template <typename T>
T read(const char*& buffer) {
T val{};
std::memcpy(&val, buffer, sizeof(T));
buffer += sizeof(T);
return val;
}
} // namespace fastdeploy

View File

@@ -124,14 +124,18 @@ bool TrtBackend::InitFromPaddle(const std::string& model_file,
option_ = option;
#ifdef ENABLE_PADDLE_FRONTEND
std::vector<paddle2onnx::CustomOp> ops;
ops.resize(1);
strcpy(ops[0].op_name, "pool2d");
strcpy(ops[0].export_op_name, "AdaptivePool2d");
char* model_content_ptr;
int model_content_size = 0;
char* calibration_cache_ptr;
int calibration_cache_size = 0;
if (!paddle2onnx::Export(model_file.c_str(), params_file.c_str(),
&model_content_ptr, &model_content_size, 11, true,
verbose, true, true, true, nullptr,
0, "tensorrt",
verbose, true, true, true, ops.data(),
1, "tensorrt",
&calibration_cache_ptr, &calibration_cache_size, "", &save_external_)) {
FDERROR << "Error occured while export PaddlePaddle to ONNX format."
<< std::endl;