mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-07 17:41:52 +08:00
[Backend] Add AdaptivePool2d for TensorRT plugin (#668)
* add adaptivepool2d for tensorrt plugin * update code * update code * update code to fix bug
This commit is contained in:
@@ -164,7 +164,7 @@ configure_file(${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/pybind/main.cc
|
|||||||
file(GLOB_RECURSE ALL_DEPLOY_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/*.cc)
|
file(GLOB_RECURSE ALL_DEPLOY_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/*.cc)
|
||||||
file(GLOB_RECURSE FDTENSOR_FUNC_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/function/*.cc)
|
file(GLOB_RECURSE FDTENSOR_FUNC_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/function/*.cc)
|
||||||
file(GLOB_RECURSE FDTENSOR_FUNC_CUDA_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/function/*.cu)
|
file(GLOB_RECURSE FDTENSOR_FUNC_CUDA_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/function/*.cu)
|
||||||
file(GLOB_RECURSE DEPLOY_ORT_CUDA_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/backends/ort/*.cu)
|
file(GLOB_RECURSE DEPLOY_OP_CUDA_KERNEL_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/backends/op_cuda_kernels/*.cu)
|
||||||
file(GLOB_RECURSE DEPLOY_ORT_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/backends/ort/*.cc)
|
file(GLOB_RECURSE DEPLOY_ORT_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/backends/ort/*.cc)
|
||||||
file(GLOB_RECURSE DEPLOY_PADDLE_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/backends/paddle/*.cc)
|
file(GLOB_RECURSE DEPLOY_PADDLE_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/backends/paddle/*.cc)
|
||||||
file(GLOB_RECURSE DEPLOY_POROS_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/backends/poros/*.cc)
|
file(GLOB_RECURSE DEPLOY_POROS_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/backends/poros/*.cc)
|
||||||
@@ -202,7 +202,7 @@ if(ENABLE_ORT_BACKEND)
|
|||||||
include(${PROJECT_SOURCE_DIR}/cmake/onnxruntime.cmake)
|
include(${PROJECT_SOURCE_DIR}/cmake/onnxruntime.cmake)
|
||||||
list(APPEND DEPEND_LIBS external_onnxruntime)
|
list(APPEND DEPEND_LIBS external_onnxruntime)
|
||||||
if(WITH_GPU)
|
if(WITH_GPU)
|
||||||
list(APPEND ALL_DEPLOY_SRCS ${DEPLOY_ORT_CUDA_SRCS})
|
list(APPEND ALL_DEPLOY_SRCS ${DEPLOY_OP_CUDA_KERNEL_SRCS})
|
||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
@@ -361,6 +361,7 @@ if(ENABLE_TRT_BACKEND)
|
|||||||
find_library(TRT_ONNX_LIB nvonnxparser ${TRT_LIB_DIR} NO_DEFAULT_PATH)
|
find_library(TRT_ONNX_LIB nvonnxparser ${TRT_LIB_DIR} NO_DEFAULT_PATH)
|
||||||
find_library(TRT_PLUGIN_LIB nvinfer_plugin ${TRT_LIB_DIR} NO_DEFAULT_PATH)
|
find_library(TRT_PLUGIN_LIB nvinfer_plugin ${TRT_LIB_DIR} NO_DEFAULT_PATH)
|
||||||
list(APPEND DEPEND_LIBS ${TRT_INFER_LIB} ${TRT_ONNX_LIB} ${TRT_PLUGIN_LIB})
|
list(APPEND DEPEND_LIBS ${TRT_INFER_LIB} ${TRT_ONNX_LIB} ${TRT_PLUGIN_LIB})
|
||||||
|
list(APPEND ALL_DEPLOY_SRCS ${DEPLOY_OP_CUDA_KERNEL_SRCS})
|
||||||
|
|
||||||
if(NOT BUILD_ON_JETSON)
|
if(NOT BUILD_ON_JETSON)
|
||||||
if(NOT EXISTS "${CMAKE_CURRENT_BINARY_DIR}/third_libs/install/tensorrt")
|
if(NOT EXISTS "${CMAKE_CURRENT_BINARY_DIR}/third_libs/install/tensorrt")
|
||||||
|
@@ -1,15 +1,12 @@
|
|||||||
#include "adaptive_pool2d.h"
|
#include "adaptive_pool2d_kernel.h"
|
||||||
#include <cuda.h>
|
|
||||||
#include <cuda_runtime.h>
|
|
||||||
#include <cstdint>
|
|
||||||
#include <iostream>
|
|
||||||
#include <vector>
|
|
||||||
#include <math.h>
|
|
||||||
namespace fastdeploy {
|
namespace fastdeploy {
|
||||||
|
|
||||||
__global__ void CudaCastKernel(const float* in, float* out, int edge, int out_bc_offset, int in_bc_offset, int ih, int iw, int oh, int ow, bool is_avg) {
|
__global__ void CudaCastKernel(const float* in, float* out, int edge, int out_bc_offset, int in_bc_offset, int ih, int iw, int oh, int ow, bool is_avg) {
|
||||||
int position = blockDim.x * blockIdx.x + threadIdx.x;
|
int position = blockDim.x * blockIdx.x + threadIdx.x;
|
||||||
if (position >= edge) return;
|
if (position >= edge) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
int offset = floorf(float(position) / out_bc_offset);
|
int offset = floorf(float(position) / out_bc_offset);
|
||||||
int h = floorf(float(position % out_bc_offset) / ow);
|
int h = floorf(float(position % out_bc_offset) / ow);
|
||||||
int w = (position % out_bc_offset) % ow;
|
int w = (position % out_bc_offset) % ow;
|
||||||
@@ -17,17 +14,17 @@ __global__ void CudaCastKernel(const float* in, float* out, int edge, int out_b
|
|||||||
int hend = ceilf(static_cast<float>((h + 1) * ih) / oh);
|
int hend = ceilf(static_cast<float>((h + 1) * ih) / oh);
|
||||||
int wstart = floorf(static_cast<float>(w * iw) / ow);
|
int wstart = floorf(static_cast<float>(w * iw) / ow);
|
||||||
int wend = ceilf(static_cast<float>((w + 1) * iw) / ow);
|
int wend = ceilf(static_cast<float>((w + 1) * iw) / ow);
|
||||||
if(is_avg){
|
if(is_avg) {
|
||||||
out[position] = 0.0;
|
out[position] = 0.0;
|
||||||
}else{
|
} else {
|
||||||
out[position] = in[offset * in_bc_offset + hstart * iw + wstart];
|
out[position] = in[offset * in_bc_offset + hstart * iw + wstart];
|
||||||
}
|
}
|
||||||
for (int h = hstart; h < hend; ++h) {
|
for (int h = hstart; h < hend; ++h) {
|
||||||
for (int w = wstart; w < wend; ++w) {
|
for (int w = wstart; w < wend; ++w) {
|
||||||
int input_idx = h * iw + w;
|
int input_idx = h * iw + w;
|
||||||
if(is_avg){
|
if(is_avg) {
|
||||||
out[position] = out[position] + in[offset * in_bc_offset + input_idx];
|
out[position] = out[position] + in[offset * in_bc_offset + input_idx];
|
||||||
}else{
|
} else {
|
||||||
out[position] = max(out[position], in[offset * in_bc_offset + input_idx]);
|
out[position] = max(out[position], in[offset * in_bc_offset + input_idx]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -40,7 +37,7 @@ void CudaAdaptivePool(const std::vector<int64_t>& input_dims, const std::vector<
|
|||||||
int out_bc_offset = output_dims[2] * output_dims[3];
|
int out_bc_offset = output_dims[2] * output_dims[3];
|
||||||
int in_bc_offset = input_dims[2] * input_dims[3];
|
int in_bc_offset = input_dims[2] * input_dims[3];
|
||||||
int jobs = 1;
|
int jobs = 1;
|
||||||
for(int i : output_dims){
|
for(int i : output_dims) {
|
||||||
jobs *= i;
|
jobs *= i;
|
||||||
}
|
}
|
||||||
bool is_avg = pooling_type == "avg";
|
bool is_avg = pooling_type == "avg";
|
35
fastdeploy/backends/op_cuda_kernels/adaptive_pool2d_kernel.h
Executable file
35
fastdeploy/backends/op_cuda_kernels/adaptive_pool2d_kernel.h
Executable file
@@ -0,0 +1,35 @@
|
|||||||
|
|
||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <cuda.h>
|
||||||
|
#include <cuda_runtime.h>
|
||||||
|
#include <cstdint>
|
||||||
|
#include <iostream>
|
||||||
|
#include <vector>
|
||||||
|
#include <math.h>
|
||||||
|
|
||||||
|
namespace fastdeploy {
|
||||||
|
|
||||||
|
void CudaAdaptivePool(const std::vector<int64_t>& input_dims,
|
||||||
|
const std::vector<int64_t>& output_dims,
|
||||||
|
float* output,
|
||||||
|
const float* input,
|
||||||
|
void* compute_stream,
|
||||||
|
const std::string& pooling_type);
|
||||||
|
|
||||||
|
|
||||||
|
} // namespace fastdeploy
|
@@ -14,14 +14,9 @@
|
|||||||
|
|
||||||
#ifndef NON_64_PLATFORM
|
#ifndef NON_64_PLATFORM
|
||||||
|
|
||||||
#include "fastdeploy/backends/ort/ops/adaptive_pool2d.h"
|
#include "adaptive_pool2d.h"
|
||||||
#include <algorithm>
|
|
||||||
#include <cmath>
|
|
||||||
#include "fastdeploy/core/fd_tensor.h"
|
|
||||||
#include "fastdeploy/utils/utils.h"
|
|
||||||
|
|
||||||
namespace fastdeploy {
|
namespace fastdeploy {
|
||||||
|
|
||||||
struct OrtTensorDimensions : std::vector<int64_t> {
|
struct OrtTensorDimensions : std::vector<int64_t> {
|
||||||
OrtTensorDimensions(Ort::CustomOpApi ort, const OrtValue* value) {
|
OrtTensorDimensions(Ort::CustomOpApi ort, const OrtValue* value) {
|
||||||
OrtTensorTypeAndShapeInfo* info = ort.GetTensorTypeAndShape(value);
|
OrtTensorTypeAndShapeInfo* info = ort.GetTensorTypeAndShape(value);
|
||||||
|
@@ -16,19 +16,19 @@
|
|||||||
|
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <algorithm>
|
||||||
|
#include <cmath>
|
||||||
|
#include "fastdeploy/core/fd_tensor.h"
|
||||||
|
#include "fastdeploy/utils/utils.h"
|
||||||
|
|
||||||
#ifndef NON_64_PLATFORM
|
#ifndef NON_64_PLATFORM
|
||||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||||
|
|
||||||
namespace fastdeploy {
|
|
||||||
#ifdef WITH_GPU
|
#ifdef WITH_GPU
|
||||||
void CudaAdaptivePool(const std::vector<int64_t>& input_dims,
|
#include "fastdeploy/backends/op_cuda_kernels/adaptive_pool2d_kernel.h"
|
||||||
const std::vector<int64_t>& output_dims,
|
|
||||||
float* output,
|
|
||||||
const float* input,
|
|
||||||
void* compute_stream,
|
|
||||||
const std::string& pooling_type);
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
namespace fastdeploy {
|
||||||
struct AdaptivePool2dKernel {
|
struct AdaptivePool2dKernel {
|
||||||
protected:
|
protected:
|
||||||
std::string pooling_type_ = "avg";
|
std::string pooling_type_ = "avg";
|
||||||
|
206
fastdeploy/backends/tensorrt/ops/adaptive_pool2d.cc
Executable file
206
fastdeploy/backends/tensorrt/ops/adaptive_pool2d.cc
Executable file
@@ -0,0 +1,206 @@
|
|||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "adaptive_pool2d.h"
|
||||||
|
|
||||||
|
namespace fastdeploy {
|
||||||
|
|
||||||
|
nvinfer1::PluginFieldCollection AdaptivePool2dPluginCreator::mFC{};
|
||||||
|
std::vector<nvinfer1::PluginField> AdaptivePool2dPluginCreator::mPluginAttributes;
|
||||||
|
|
||||||
|
pluginStatus_t AdaptivePool2dInference(cudaStream_t stream, int32_t n, const void* input, void* output);
|
||||||
|
|
||||||
|
AdaptivePool2d::AdaptivePool2d(std::vector<int32_t> output_size, std::string pooling_type) {
|
||||||
|
output_size_ = output_size;
|
||||||
|
pooling_type_ = pooling_type;
|
||||||
|
}
|
||||||
|
|
||||||
|
AdaptivePool2d::AdaptivePool2d(const void* buffer, size_t length) {
|
||||||
|
const char *d = reinterpret_cast<const char*>(buffer), *a = d;
|
||||||
|
output_size_.resize(4);
|
||||||
|
for(int64_t i =0 ; i < 4; i++){
|
||||||
|
output_size_[i] =read<int32_t>(d);
|
||||||
|
}
|
||||||
|
if(read<int32_t>(d) == 0){
|
||||||
|
pooling_type_ = "avg";
|
||||||
|
}else{
|
||||||
|
pooling_type_ = "max";
|
||||||
|
}
|
||||||
|
FDASSERT(d == a + length, "deserialize failed.");
|
||||||
|
}
|
||||||
|
|
||||||
|
int AdaptivePool2d::getNbOutputs() const noexcept {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
nvinfer1::DimsExprs AdaptivePool2d::getOutputDimensions(
|
||||||
|
int outputIndex, const nvinfer1::DimsExprs* inputs,
|
||||||
|
int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept {
|
||||||
|
try {
|
||||||
|
nvinfer1::DimsExprs output(inputs[0]);
|
||||||
|
output.d[2] = exprBuilder.constant(static_cast<int32_t>(output_size_[2]));
|
||||||
|
output.d[3] = exprBuilder.constant(static_cast<int32_t>(output_size_[3]));
|
||||||
|
return output;
|
||||||
|
}
|
||||||
|
catch (const std::exception& e) {
|
||||||
|
FDASSERT(false, "getOutputDimensions failed: %s.",e.what());
|
||||||
|
}
|
||||||
|
return nvinfer1::DimsExprs{};
|
||||||
|
}
|
||||||
|
|
||||||
|
int AdaptivePool2d::enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
|
||||||
|
const nvinfer1::PluginTensorDesc* outputDesc,
|
||||||
|
const void* const* inputs,
|
||||||
|
void* const* outputs,
|
||||||
|
void* workspace,
|
||||||
|
cudaStream_t stream) noexcept {
|
||||||
|
if (inputDesc[0].type != nvinfer1::DataType::kFLOAT) {
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
auto const* data = static_cast<float const*>(inputs[0]);
|
||||||
|
auto* result = static_cast<float*>(outputs[0]);
|
||||||
|
int nums = outputDesc[0].dims.d[0] * outputDesc[0].dims.d[1] * outputDesc[0].dims.d[2]* outputDesc[0].dims.d[3];
|
||||||
|
std::vector<int64_t> input_size, output_size;
|
||||||
|
for(int i =0; i< 4; i++){
|
||||||
|
input_size.push_back(inputDesc[0].dims.d[i]);
|
||||||
|
output_size.push_back(outputDesc[0].dims.d[i]);
|
||||||
|
}
|
||||||
|
CudaAdaptivePool(input_size, output_size, result, data, stream, pooling_type_);
|
||||||
|
return cudaPeekAtLastError();
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t AdaptivePool2d::getSerializationSize() const noexcept {
|
||||||
|
return 5 * sizeof(int32_t) ;
|
||||||
|
}
|
||||||
|
|
||||||
|
void AdaptivePool2d::serialize(void* buffer) const noexcept {
|
||||||
|
char *d = reinterpret_cast<char*>(buffer), *a = d;
|
||||||
|
for(int64_t i=0; i< 4; i++){
|
||||||
|
write(d, output_size_[i]);
|
||||||
|
}
|
||||||
|
int32_t pooling_type_val = 0;
|
||||||
|
if(pooling_type_ != "avg"){
|
||||||
|
pooling_type_val = 1;
|
||||||
|
}
|
||||||
|
write(d, pooling_type_val);
|
||||||
|
FDASSERT(d == a + getSerializationSize(), "d == a + getSerializationSize()");
|
||||||
|
}
|
||||||
|
|
||||||
|
nvinfer1::DataType AdaptivePool2d::getOutputDataType(
|
||||||
|
int index, const nvinfer1::DataType* inputType, int nbInputs) const noexcept {
|
||||||
|
return inputType[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
bool AdaptivePool2d::supportsFormatCombination(
|
||||||
|
int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept {
|
||||||
|
return (inOut[pos].format == nvinfer1::PluginFormat::kLINEAR);
|
||||||
|
}
|
||||||
|
|
||||||
|
int AdaptivePool2d::initialize() noexcept {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
void AdaptivePool2d::terminate() noexcept {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t AdaptivePool2d::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
|
||||||
|
int nbInputs,
|
||||||
|
const nvinfer1::PluginTensorDesc* outputs,
|
||||||
|
int nbOutputs) const noexcept {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
const char* AdaptivePool2d::getPluginType() const noexcept {
|
||||||
|
return "AdaptivePool2d";
|
||||||
|
}
|
||||||
|
|
||||||
|
const char* AdaptivePool2d::getPluginVersion() const noexcept {
|
||||||
|
return "1";
|
||||||
|
}
|
||||||
|
|
||||||
|
void AdaptivePool2d::destroy() noexcept {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
void AdaptivePool2d::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs,
|
||||||
|
const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
nvinfer1::IPluginV2DynamicExt* AdaptivePool2d::clone() const noexcept {
|
||||||
|
try{
|
||||||
|
nvinfer1::IPluginV2DynamicExt* plugin = new AdaptivePool2d(output_size_, pooling_type_);
|
||||||
|
plugin->setPluginNamespace(mNamespace.c_str());
|
||||||
|
return plugin;
|
||||||
|
}
|
||||||
|
catch (std::exception const& e){
|
||||||
|
FDASSERT(false, "clone failed: %s.",e.what());
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
AdaptivePool2dPluginCreator::AdaptivePool2dPluginCreator() {
|
||||||
|
mPluginAttributes.clear();
|
||||||
|
mPluginAttributes.emplace_back(nvinfer1::PluginField("output_size", nullptr, nvinfer1::PluginFieldType::kINT32, 4));
|
||||||
|
mPluginAttributes.emplace_back(nvinfer1::PluginField("pooling_type", nullptr, nvinfer1::PluginFieldType::kCHAR, 3));
|
||||||
|
|
||||||
|
mFC.nbFields = mPluginAttributes.size();
|
||||||
|
mFC.fields = mPluginAttributes.data();
|
||||||
|
}
|
||||||
|
|
||||||
|
const char* AdaptivePool2dPluginCreator::getPluginName() const noexcept {
|
||||||
|
return "AdaptivePool2d";
|
||||||
|
}
|
||||||
|
|
||||||
|
const char* AdaptivePool2dPluginCreator::getPluginVersion() const noexcept {
|
||||||
|
return "1";
|
||||||
|
}
|
||||||
|
|
||||||
|
const nvinfer1::PluginFieldCollection* AdaptivePool2dPluginCreator::getFieldNames() noexcept {
|
||||||
|
return &mFC;
|
||||||
|
}
|
||||||
|
|
||||||
|
nvinfer1::IPluginV2DynamicExt* AdaptivePool2dPluginCreator::createPlugin(const char* name,
|
||||||
|
const nvinfer1::PluginFieldCollection* fc) noexcept {
|
||||||
|
try{
|
||||||
|
const nvinfer1::PluginField* fields = fc->fields;
|
||||||
|
auto const dims = static_cast<int32_t const*>(fields[0].data);
|
||||||
|
output_size_.resize(4);
|
||||||
|
for(int64_t i = 0; i < 4; i++){
|
||||||
|
output_size_[i] = dims[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
const char* pooling_type_ptr = (static_cast<char const*>(fields[1].data));
|
||||||
|
std::string pooling_type(pooling_type_ptr, 3);
|
||||||
|
pooling_type_ = pooling_type;
|
||||||
|
return new AdaptivePool2d(output_size_, pooling_type_);
|
||||||
|
}
|
||||||
|
catch (std::exception const& e){
|
||||||
|
FDASSERT(false, "createPlugin failed: %s.",e.what());
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
nvinfer1::IPluginV2DynamicExt* AdaptivePool2dPluginCreator::deserializePlugin(const char* name,
|
||||||
|
const void* serialData,
|
||||||
|
size_t serialLength) noexcept {
|
||||||
|
try{
|
||||||
|
return new AdaptivePool2d(serialData, serialLength);
|
||||||
|
}
|
||||||
|
catch (std::exception const& e){
|
||||||
|
FDASSERT(false, "deserializePlugin failed: %s.",e.what());
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace fastdeploy
|
112
fastdeploy/backends/tensorrt/ops/adaptive_pool2d.h
Executable file
112
fastdeploy/backends/tensorrt/ops/adaptive_pool2d.h
Executable file
@@ -0,0 +1,112 @@
|
|||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
#include "fastdeploy/backends/op_cuda_kernels/adaptive_pool2d_kernel.h"
|
||||||
|
#include "common.h" // NOLINT
|
||||||
|
|
||||||
|
namespace fastdeploy {
|
||||||
|
|
||||||
|
class AdaptivePool2d : public BasePlugin {
|
||||||
|
public:
|
||||||
|
AdaptivePool2d(std::vector<int32_t> output_size, std::string pooling_type);
|
||||||
|
|
||||||
|
AdaptivePool2d(const void* buffer, size_t length);
|
||||||
|
|
||||||
|
~AdaptivePool2d() override = default;
|
||||||
|
|
||||||
|
int getNbOutputs() const noexcept override;
|
||||||
|
|
||||||
|
nvinfer1::DimsExprs getOutputDimensions(
|
||||||
|
int outputIndex,
|
||||||
|
const nvinfer1::DimsExprs* inputs,
|
||||||
|
int nbInputs,
|
||||||
|
nvinfer1::IExprBuilder& exprBuilder) noexcept override;
|
||||||
|
|
||||||
|
nvinfer1::DataType getOutputDataType(
|
||||||
|
int index,
|
||||||
|
const nvinfer1::DataType* inputType,
|
||||||
|
int nbInputs) const noexcept override;
|
||||||
|
|
||||||
|
bool supportsFormatCombination(
|
||||||
|
int pos,
|
||||||
|
const nvinfer1::PluginTensorDesc* inOut,
|
||||||
|
int nbInputs,
|
||||||
|
int nbOutputs) noexcept override;
|
||||||
|
|
||||||
|
int initialize() noexcept override;
|
||||||
|
|
||||||
|
void terminate() noexcept override;
|
||||||
|
|
||||||
|
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
|
||||||
|
int nbInputs,
|
||||||
|
const nvinfer1::PluginTensorDesc* outputs,
|
||||||
|
int nbOutputs) const noexcept override;
|
||||||
|
|
||||||
|
int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
|
||||||
|
const nvinfer1::PluginTensorDesc* outputDesc,
|
||||||
|
const void* const* inputs,
|
||||||
|
void* const* outputs,
|
||||||
|
void* workspace,
|
||||||
|
cudaStream_t stream) noexcept override;
|
||||||
|
|
||||||
|
size_t getSerializationSize() const noexcept override;
|
||||||
|
|
||||||
|
void serialize(void* buffer) const noexcept override;
|
||||||
|
|
||||||
|
const char* getPluginType() const noexcept override;
|
||||||
|
|
||||||
|
const char* getPluginVersion() const noexcept override;
|
||||||
|
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
|
||||||
|
int nbInputs,
|
||||||
|
const nvinfer1::DynamicPluginTensorDesc* out,
|
||||||
|
int nbOutputs) noexcept override;
|
||||||
|
void destroy() noexcept override;
|
||||||
|
|
||||||
|
nvinfer1::IPluginV2DynamicExt* clone() const noexcept override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::vector<int32_t> output_size_;
|
||||||
|
std::string pooling_type_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class AdaptivePool2dPluginCreator : public BaseCreator {
|
||||||
|
public:
|
||||||
|
AdaptivePool2dPluginCreator();
|
||||||
|
|
||||||
|
~AdaptivePool2dPluginCreator() override = default;
|
||||||
|
|
||||||
|
const char* getPluginName() const noexcept override;
|
||||||
|
|
||||||
|
const char* getPluginVersion() const noexcept override;
|
||||||
|
|
||||||
|
const nvinfer1::PluginFieldCollection* getFieldNames() noexcept override;
|
||||||
|
|
||||||
|
nvinfer1::IPluginV2DynamicExt* createPlugin(const char* name,
|
||||||
|
const nvinfer1::PluginFieldCollection* fc) noexcept override;
|
||||||
|
|
||||||
|
nvinfer1::IPluginV2DynamicExt* deserializePlugin(const char* name,
|
||||||
|
const void* serialData,
|
||||||
|
size_t serialLength) noexcept override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
static nvinfer1::PluginFieldCollection mFC;
|
||||||
|
static std::vector<nvinfer1::PluginField> mPluginAttributes;
|
||||||
|
std::vector<int32_t> output_size_;
|
||||||
|
std::string pooling_type_;
|
||||||
|
};
|
||||||
|
|
||||||
|
REGISTER_TENSORRT_PLUGIN(AdaptivePool2dPluginCreator);
|
||||||
|
|
||||||
|
} // namespace fastdeploy
|
80
fastdeploy/backends/tensorrt/ops/common.h
Executable file
80
fastdeploy/backends/tensorrt/ops/common.h
Executable file
@@ -0,0 +1,80 @@
|
|||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "NvInferPlugin.h"
|
||||||
|
#include "NvInferRuntimeCommon.h"
|
||||||
|
#include "fastdeploy/utils/utils.h"
|
||||||
|
#include <iostream>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <memory>
|
||||||
|
#include <cstring>
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
namespace fastdeploy {
|
||||||
|
|
||||||
|
class BasePlugin : public nvinfer1::IPluginV2DynamicExt {
|
||||||
|
protected:
|
||||||
|
void setPluginNamespace(const char* libNamespace) noexcept override {
|
||||||
|
mNamespace = libNamespace;
|
||||||
|
}
|
||||||
|
|
||||||
|
const char* getPluginNamespace() const noexcept override {
|
||||||
|
return mNamespace.c_str();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string mNamespace;
|
||||||
|
};
|
||||||
|
|
||||||
|
class BaseCreator : public nvinfer1::IPluginCreator {
|
||||||
|
public:
|
||||||
|
void setPluginNamespace(const char* libNamespace) noexcept override {
|
||||||
|
mNamespace = libNamespace;
|
||||||
|
}
|
||||||
|
|
||||||
|
const char* getPluginNamespace() const noexcept override {
|
||||||
|
return mNamespace.c_str();
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
std::string mNamespace;
|
||||||
|
};
|
||||||
|
|
||||||
|
typedef enum {
|
||||||
|
STATUS_SUCCESS = 0,
|
||||||
|
STATUS_FAILURE = 1,
|
||||||
|
STATUS_BAD_PARAM = 2,
|
||||||
|
STATUS_NOT_SUPPORTED = 3,
|
||||||
|
STATUS_NOT_INITIALIZED = 4
|
||||||
|
} pluginStatus_t;
|
||||||
|
|
||||||
|
// Write values into buffer
|
||||||
|
template <typename T>
|
||||||
|
void write(char*& buffer, const T& val) {
|
||||||
|
std::memcpy(buffer, &val, sizeof(T));
|
||||||
|
buffer += sizeof(T);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read values from buffer
|
||||||
|
template <typename T>
|
||||||
|
T read(const char*& buffer) {
|
||||||
|
T val{};
|
||||||
|
std::memcpy(&val, buffer, sizeof(T));
|
||||||
|
buffer += sizeof(T);
|
||||||
|
return val;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace fastdeploy
|
@@ -124,14 +124,18 @@ bool TrtBackend::InitFromPaddle(const std::string& model_file,
|
|||||||
option_ = option;
|
option_ = option;
|
||||||
|
|
||||||
#ifdef ENABLE_PADDLE_FRONTEND
|
#ifdef ENABLE_PADDLE_FRONTEND
|
||||||
|
std::vector<paddle2onnx::CustomOp> ops;
|
||||||
|
ops.resize(1);
|
||||||
|
strcpy(ops[0].op_name, "pool2d");
|
||||||
|
strcpy(ops[0].export_op_name, "AdaptivePool2d");
|
||||||
char* model_content_ptr;
|
char* model_content_ptr;
|
||||||
int model_content_size = 0;
|
int model_content_size = 0;
|
||||||
char* calibration_cache_ptr;
|
char* calibration_cache_ptr;
|
||||||
int calibration_cache_size = 0;
|
int calibration_cache_size = 0;
|
||||||
if (!paddle2onnx::Export(model_file.c_str(), params_file.c_str(),
|
if (!paddle2onnx::Export(model_file.c_str(), params_file.c_str(),
|
||||||
&model_content_ptr, &model_content_size, 11, true,
|
&model_content_ptr, &model_content_size, 11, true,
|
||||||
verbose, true, true, true, nullptr,
|
verbose, true, true, true, ops.data(),
|
||||||
0, "tensorrt",
|
1, "tensorrt",
|
||||||
&calibration_cache_ptr, &calibration_cache_size, "", &save_external_)) {
|
&calibration_cache_ptr, &calibration_cache_size, "", &save_external_)) {
|
||||||
FDERROR << "Error occured while export PaddlePaddle to ONNX format."
|
FDERROR << "Error occured while export PaddlePaddle to ONNX format."
|
||||||
<< std::endl;
|
<< std::endl;
|
||||||
|
Reference in New Issue
Block a user