diff --git a/CMakeLists.txt b/CMakeLists.txt index 3aee32349..2675b1f97 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -164,7 +164,7 @@ configure_file(${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/pybind/main.cc file(GLOB_RECURSE ALL_DEPLOY_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/*.cc) file(GLOB_RECURSE FDTENSOR_FUNC_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/function/*.cc) file(GLOB_RECURSE FDTENSOR_FUNC_CUDA_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/function/*.cu) -file(GLOB_RECURSE DEPLOY_ORT_CUDA_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/backends/ort/*.cu) +file(GLOB_RECURSE DEPLOY_OP_CUDA_KERNEL_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/backends/op_cuda_kernels/*.cu) file(GLOB_RECURSE DEPLOY_ORT_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/backends/ort/*.cc) file(GLOB_RECURSE DEPLOY_PADDLE_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/backends/paddle/*.cc) file(GLOB_RECURSE DEPLOY_POROS_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/backends/poros/*.cc) @@ -202,7 +202,7 @@ if(ENABLE_ORT_BACKEND) include(${PROJECT_SOURCE_DIR}/cmake/onnxruntime.cmake) list(APPEND DEPEND_LIBS external_onnxruntime) if(WITH_GPU) - list(APPEND ALL_DEPLOY_SRCS ${DEPLOY_ORT_CUDA_SRCS}) + list(APPEND ALL_DEPLOY_SRCS ${DEPLOY_OP_CUDA_KERNEL_SRCS}) endif() endif() @@ -361,6 +361,7 @@ if(ENABLE_TRT_BACKEND) find_library(TRT_ONNX_LIB nvonnxparser ${TRT_LIB_DIR} NO_DEFAULT_PATH) find_library(TRT_PLUGIN_LIB nvinfer_plugin ${TRT_LIB_DIR} NO_DEFAULT_PATH) list(APPEND DEPEND_LIBS ${TRT_INFER_LIB} ${TRT_ONNX_LIB} ${TRT_PLUGIN_LIB}) + list(APPEND ALL_DEPLOY_SRCS ${DEPLOY_OP_CUDA_KERNEL_SRCS}) if(NOT BUILD_ON_JETSON) if(NOT EXISTS "${CMAKE_CURRENT_BINARY_DIR}/third_libs/install/tensorrt") diff --git a/fastdeploy/backends/ort/ops/adaptive_pooling.cu b/fastdeploy/backends/op_cuda_kernels/adaptive_pool2d_kernel.cu similarity index 83% rename from fastdeploy/backends/ort/ops/adaptive_pooling.cu rename to fastdeploy/backends/op_cuda_kernels/adaptive_pool2d_kernel.cu index 607279267..2fa63f36d 100755 --- a/fastdeploy/backends/ort/ops/adaptive_pooling.cu +++ b/fastdeploy/backends/op_cuda_kernels/adaptive_pool2d_kernel.cu @@ -1,15 +1,12 @@ -#include "adaptive_pool2d.h" -#include -#include -#include -#include -#include -#include +#include "adaptive_pool2d_kernel.h" + namespace fastdeploy { __global__ void CudaCastKernel(const float* in, float* out, int edge, int out_bc_offset, int in_bc_offset, int ih, int iw, int oh, int ow, bool is_avg) { int position = blockDim.x * blockIdx.x + threadIdx.x; - if (position >= edge) return; + if (position >= edge) { + return; + } int offset = floorf(float(position) / out_bc_offset); int h = floorf(float(position % out_bc_offset) / ow); int w = (position % out_bc_offset) % ow; @@ -17,18 +14,18 @@ __global__ void CudaCastKernel(const float* in, float* out, int edge, int out_b int hend = ceilf(static_cast((h + 1) * ih) / oh); int wstart = floorf(static_cast(w * iw) / ow); int wend = ceilf(static_cast((w + 1) * iw) / ow); - if(is_avg){ + if(is_avg) { out[position] = 0.0; - }else{ + } else { out[position] = in[offset * in_bc_offset + hstart * iw + wstart]; } for (int h = hstart; h < hend; ++h) { for (int w = wstart; w < wend; ++w) { int input_idx = h * iw + w; - if(is_avg){ + if(is_avg) { out[position] = out[position] + in[offset * in_bc_offset + input_idx]; - }else{ - out[position] = max(out[position], in[offset * in_bc_offset + input_idx]); + } else { + out[position] = max(out[position], in[offset * in_bc_offset + input_idx]); } } } @@ -40,7 +37,7 @@ void CudaAdaptivePool(const std::vector& input_dims, const std::vector< int out_bc_offset = output_dims[2] * output_dims[3]; int in_bc_offset = input_dims[2] * input_dims[3]; int jobs = 1; - for(int i : output_dims){ + for(int i : output_dims) { jobs *= i; } bool is_avg = pooling_type == "avg"; diff --git a/fastdeploy/backends/op_cuda_kernels/adaptive_pool2d_kernel.h b/fastdeploy/backends/op_cuda_kernels/adaptive_pool2d_kernel.h new file mode 100755 index 000000000..3e68908ed --- /dev/null +++ b/fastdeploy/backends/op_cuda_kernels/adaptive_pool2d_kernel.h @@ -0,0 +1,35 @@ + +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace fastdeploy { + +void CudaAdaptivePool(const std::vector& input_dims, + const std::vector& output_dims, + float* output, + const float* input, + void* compute_stream, + const std::string& pooling_type); + + +} // namespace fastdeploy diff --git a/fastdeploy/backends/ort/ops/adaptive_pool2d.cc b/fastdeploy/backends/ort/ops/adaptive_pool2d.cc index 43fd380fe..7b4ec7ad4 100755 --- a/fastdeploy/backends/ort/ops/adaptive_pool2d.cc +++ b/fastdeploy/backends/ort/ops/adaptive_pool2d.cc @@ -14,14 +14,9 @@ #ifndef NON_64_PLATFORM -#include "fastdeploy/backends/ort/ops/adaptive_pool2d.h" -#include -#include -#include "fastdeploy/core/fd_tensor.h" -#include "fastdeploy/utils/utils.h" +#include "adaptive_pool2d.h" namespace fastdeploy { - struct OrtTensorDimensions : std::vector { OrtTensorDimensions(Ort::CustomOpApi ort, const OrtValue* value) { OrtTensorTypeAndShapeInfo* info = ort.GetTensorTypeAndShape(value); diff --git a/fastdeploy/backends/ort/ops/adaptive_pool2d.h b/fastdeploy/backends/ort/ops/adaptive_pool2d.h index 2912efa72..556ca033b 100755 --- a/fastdeploy/backends/ort/ops/adaptive_pool2d.h +++ b/fastdeploy/backends/ort/ops/adaptive_pool2d.h @@ -16,19 +16,19 @@ #include #include +#include +#include +#include "fastdeploy/core/fd_tensor.h" +#include "fastdeploy/utils/utils.h" #ifndef NON_64_PLATFORM #include "onnxruntime_cxx_api.h" // NOLINT -namespace fastdeploy { #ifdef WITH_GPU -void CudaAdaptivePool(const std::vector& input_dims, - const std::vector& output_dims, - float* output, - const float* input, - void* compute_stream, - const std::string& pooling_type); +#include "fastdeploy/backends/op_cuda_kernels/adaptive_pool2d_kernel.h" #endif + +namespace fastdeploy { struct AdaptivePool2dKernel { protected: std::string pooling_type_ = "avg"; diff --git a/fastdeploy/backends/tensorrt/ops/adaptive_pool2d.cc b/fastdeploy/backends/tensorrt/ops/adaptive_pool2d.cc new file mode 100755 index 000000000..bfec5e356 --- /dev/null +++ b/fastdeploy/backends/tensorrt/ops/adaptive_pool2d.cc @@ -0,0 +1,206 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "adaptive_pool2d.h" + +namespace fastdeploy { + +nvinfer1::PluginFieldCollection AdaptivePool2dPluginCreator::mFC{}; +std::vector AdaptivePool2dPluginCreator::mPluginAttributes; + +pluginStatus_t AdaptivePool2dInference(cudaStream_t stream, int32_t n, const void* input, void* output); + +AdaptivePool2d::AdaptivePool2d(std::vector output_size, std::string pooling_type) { + output_size_ = output_size; + pooling_type_ = pooling_type; +} + +AdaptivePool2d::AdaptivePool2d(const void* buffer, size_t length) { + const char *d = reinterpret_cast(buffer), *a = d; + output_size_.resize(4); + for(int64_t i =0 ; i < 4; i++){ + output_size_[i] =read(d); + } + if(read(d) == 0){ + pooling_type_ = "avg"; + }else{ + pooling_type_ = "max"; + } + FDASSERT(d == a + length, "deserialize failed."); +} + +int AdaptivePool2d::getNbOutputs() const noexcept { + return 1; +} + +nvinfer1::DimsExprs AdaptivePool2d::getOutputDimensions( + int outputIndex, const nvinfer1::DimsExprs* inputs, + int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept { + try { + nvinfer1::DimsExprs output(inputs[0]); + output.d[2] = exprBuilder.constant(static_cast(output_size_[2])); + output.d[3] = exprBuilder.constant(static_cast(output_size_[3])); + return output; + } + catch (const std::exception& e) { + FDASSERT(false, "getOutputDimensions failed: %s.",e.what()); + } + return nvinfer1::DimsExprs{}; +} + +int AdaptivePool2d::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workspace, + cudaStream_t stream) noexcept { + if (inputDesc[0].type != nvinfer1::DataType::kFLOAT) { + return -1; + } + auto const* data = static_cast(inputs[0]); + auto* result = static_cast(outputs[0]); + int nums = outputDesc[0].dims.d[0] * outputDesc[0].dims.d[1] * outputDesc[0].dims.d[2]* outputDesc[0].dims.d[3]; + std::vector input_size, output_size; + for(int i =0; i< 4; i++){ + input_size.push_back(inputDesc[0].dims.d[i]); + output_size.push_back(outputDesc[0].dims.d[i]); + } + CudaAdaptivePool(input_size, output_size, result, data, stream, pooling_type_); + return cudaPeekAtLastError(); +} + +size_t AdaptivePool2d::getSerializationSize() const noexcept { + return 5 * sizeof(int32_t) ; +} + +void AdaptivePool2d::serialize(void* buffer) const noexcept { + char *d = reinterpret_cast(buffer), *a = d; + for(int64_t i=0; i< 4; i++){ + write(d, output_size_[i]); + } + int32_t pooling_type_val = 0; + if(pooling_type_ != "avg"){ + pooling_type_val = 1; + } + write(d, pooling_type_val); + FDASSERT(d == a + getSerializationSize(), "d == a + getSerializationSize()"); +} + +nvinfer1::DataType AdaptivePool2d::getOutputDataType( + int index, const nvinfer1::DataType* inputType, int nbInputs) const noexcept { + return inputType[0]; +} + +bool AdaptivePool2d::supportsFormatCombination( + int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept { + return (inOut[pos].format == nvinfer1::PluginFormat::kLINEAR); +} + +int AdaptivePool2d::initialize() noexcept { + return 0; +} + +void AdaptivePool2d::terminate() noexcept { + return; +} + +size_t AdaptivePool2d::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc* outputs, + int nbOutputs) const noexcept { + return 0; +} + +const char* AdaptivePool2d::getPluginType() const noexcept { + return "AdaptivePool2d"; +} + +const char* AdaptivePool2d::getPluginVersion() const noexcept { + return "1"; +} + +void AdaptivePool2d::destroy() noexcept { + return; +} +void AdaptivePool2d::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept { + return; +} +nvinfer1::IPluginV2DynamicExt* AdaptivePool2d::clone() const noexcept { + try{ + nvinfer1::IPluginV2DynamicExt* plugin = new AdaptivePool2d(output_size_, pooling_type_); + plugin->setPluginNamespace(mNamespace.c_str()); + return plugin; + } + catch (std::exception const& e){ + FDASSERT(false, "clone failed: %s.",e.what()); + } + return nullptr; +} + +AdaptivePool2dPluginCreator::AdaptivePool2dPluginCreator() { + mPluginAttributes.clear(); + mPluginAttributes.emplace_back(nvinfer1::PluginField("output_size", nullptr, nvinfer1::PluginFieldType::kINT32, 4)); + mPluginAttributes.emplace_back(nvinfer1::PluginField("pooling_type", nullptr, nvinfer1::PluginFieldType::kCHAR, 3)); + + mFC.nbFields = mPluginAttributes.size(); + mFC.fields = mPluginAttributes.data(); +} + +const char* AdaptivePool2dPluginCreator::getPluginName() const noexcept { + return "AdaptivePool2d"; +} + +const char* AdaptivePool2dPluginCreator::getPluginVersion() const noexcept { + return "1"; +} + +const nvinfer1::PluginFieldCollection* AdaptivePool2dPluginCreator::getFieldNames() noexcept { + return &mFC; +} + +nvinfer1::IPluginV2DynamicExt* AdaptivePool2dPluginCreator::createPlugin(const char* name, + const nvinfer1::PluginFieldCollection* fc) noexcept { + try{ + const nvinfer1::PluginField* fields = fc->fields; + auto const dims = static_cast(fields[0].data); + output_size_.resize(4); + for(int64_t i = 0; i < 4; i++){ + output_size_[i] = dims[i]; + } + + const char* pooling_type_ptr = (static_cast(fields[1].data)); + std::string pooling_type(pooling_type_ptr, 3); + pooling_type_ = pooling_type; + return new AdaptivePool2d(output_size_, pooling_type_); + } + catch (std::exception const& e){ + FDASSERT(false, "createPlugin failed: %s.",e.what()); + } + return nullptr; +} + +nvinfer1::IPluginV2DynamicExt* AdaptivePool2dPluginCreator::deserializePlugin(const char* name, + const void* serialData, + size_t serialLength) noexcept { + try{ + return new AdaptivePool2d(serialData, serialLength); + } + catch (std::exception const& e){ + FDASSERT(false, "deserializePlugin failed: %s.",e.what()); + } + return nullptr; +} + +} // namespace fastdeploy \ No newline at end of file diff --git a/fastdeploy/backends/tensorrt/ops/adaptive_pool2d.h b/fastdeploy/backends/tensorrt/ops/adaptive_pool2d.h new file mode 100755 index 000000000..2e6e45e2c --- /dev/null +++ b/fastdeploy/backends/tensorrt/ops/adaptive_pool2d.h @@ -0,0 +1,112 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "fastdeploy/backends/op_cuda_kernels/adaptive_pool2d_kernel.h" +#include "common.h" // NOLINT + +namespace fastdeploy { + +class AdaptivePool2d : public BasePlugin { + public: + AdaptivePool2d(std::vector output_size, std::string pooling_type); + + AdaptivePool2d(const void* buffer, size_t length); + + ~AdaptivePool2d() override = default; + + int getNbOutputs() const noexcept override; + + nvinfer1::DimsExprs getOutputDimensions( + int outputIndex, + const nvinfer1::DimsExprs* inputs, + int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) noexcept override; + + nvinfer1::DataType getOutputDataType( + int index, + const nvinfer1::DataType* inputType, + int nbInputs) const noexcept override; + + bool supportsFormatCombination( + int pos, + const nvinfer1::PluginTensorDesc* inOut, + int nbInputs, + int nbOutputs) noexcept override; + + int initialize() noexcept override; + + void terminate() noexcept override; + + size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc* outputs, + int nbOutputs) const noexcept override; + + int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workspace, + cudaStream_t stream) noexcept override; + + size_t getSerializationSize() const noexcept override; + + void serialize(void* buffer) const noexcept override; + + const char* getPluginType() const noexcept override; + + const char* getPluginVersion() const noexcept override; + void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* out, + int nbOutputs) noexcept override; + void destroy() noexcept override; + + nvinfer1::IPluginV2DynamicExt* clone() const noexcept override; + + private: + std::vector output_size_; + std::string pooling_type_; +}; + +class AdaptivePool2dPluginCreator : public BaseCreator { + public: + AdaptivePool2dPluginCreator(); + + ~AdaptivePool2dPluginCreator() override = default; + + const char* getPluginName() const noexcept override; + + const char* getPluginVersion() const noexcept override; + + const nvinfer1::PluginFieldCollection* getFieldNames() noexcept override; + + nvinfer1::IPluginV2DynamicExt* createPlugin(const char* name, + const nvinfer1::PluginFieldCollection* fc) noexcept override; + + nvinfer1::IPluginV2DynamicExt* deserializePlugin(const char* name, + const void* serialData, + size_t serialLength) noexcept override; + + private: + static nvinfer1::PluginFieldCollection mFC; + static std::vector mPluginAttributes; + std::vector output_size_; + std::string pooling_type_; +}; + +REGISTER_TENSORRT_PLUGIN(AdaptivePool2dPluginCreator); + +} // namespace fastdeploy diff --git a/fastdeploy/backends/tensorrt/ops/common.h b/fastdeploy/backends/tensorrt/ops/common.h new file mode 100755 index 000000000..beada71ff --- /dev/null +++ b/fastdeploy/backends/tensorrt/ops/common.h @@ -0,0 +1,80 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "NvInferPlugin.h" +#include "NvInferRuntimeCommon.h" +#include "fastdeploy/utils/utils.h" +#include +#include +#include +#include +#include +#include + +namespace fastdeploy { + +class BasePlugin : public nvinfer1::IPluginV2DynamicExt { + protected: + void setPluginNamespace(const char* libNamespace) noexcept override { + mNamespace = libNamespace; + } + + const char* getPluginNamespace() const noexcept override { + return mNamespace.c_str(); + } + + std::string mNamespace; +}; + +class BaseCreator : public nvinfer1::IPluginCreator { + public: + void setPluginNamespace(const char* libNamespace) noexcept override { + mNamespace = libNamespace; + } + + const char* getPluginNamespace() const noexcept override { + return mNamespace.c_str(); + } + + protected: + std::string mNamespace; +}; + +typedef enum { + STATUS_SUCCESS = 0, + STATUS_FAILURE = 1, + STATUS_BAD_PARAM = 2, + STATUS_NOT_SUPPORTED = 3, + STATUS_NOT_INITIALIZED = 4 +} pluginStatus_t; + +// Write values into buffer +template +void write(char*& buffer, const T& val) { + std::memcpy(buffer, &val, sizeof(T)); + buffer += sizeof(T); +} + +// Read values from buffer +template +T read(const char*& buffer) { + T val{}; + std::memcpy(&val, buffer, sizeof(T)); + buffer += sizeof(T); + return val; +} + +} // namespace fastdeploy diff --git a/fastdeploy/backends/tensorrt/trt_backend.cc b/fastdeploy/backends/tensorrt/trt_backend.cc index 82cb45464..5cdc266b6 100755 --- a/fastdeploy/backends/tensorrt/trt_backend.cc +++ b/fastdeploy/backends/tensorrt/trt_backend.cc @@ -124,14 +124,18 @@ bool TrtBackend::InitFromPaddle(const std::string& model_file, option_ = option; #ifdef ENABLE_PADDLE_FRONTEND + std::vector ops; + ops.resize(1); + strcpy(ops[0].op_name, "pool2d"); + strcpy(ops[0].export_op_name, "AdaptivePool2d"); char* model_content_ptr; int model_content_size = 0; char* calibration_cache_ptr; int calibration_cache_size = 0; if (!paddle2onnx::Export(model_file.c_str(), params_file.c_str(), &model_content_ptr, &model_content_size, 11, true, - verbose, true, true, true, nullptr, - 0, "tensorrt", + verbose, true, true, true, ops.data(), + 1, "tensorrt", &calibration_cache_ptr, &calibration_cache_size, "", &save_external_)) { FDERROR << "Error occured while export PaddlePaddle to ONNX format." << std::endl;