Files
FastDeploy/fastdeploy/backends/tensorrt/ops/adaptive_pool2d.h
yeliang2258 d14828cb18 [Backend] Add AdaptivePool2d for TensorRT plugin (#668)
* add adaptivepool2d for tensorrt plugin

* update code

* update code

* update code to fix bug
2022-11-25 17:36:59 +08:00

113 lines
3.9 KiB
C++
Executable File

// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/backends/op_cuda_kernels/adaptive_pool2d_kernel.h"
#include "common.h" // NOLINT
namespace fastdeploy {
// TensorRT plugin declaration for adaptive 2D pooling.
//
// Only the interface is declared in this class body; apart from the defaulted
// destructor, no member function is defined here.
// NOTE(review): BasePlugin is declared outside this block (presumably in
// "common.h") and presumably derives from nvinfer1::IPluginV2DynamicExt, given
// the set of overrides below -- confirm.
class AdaptivePool2d : public BasePlugin {
 public:
  // ---- Construction / destruction ------------------------------------------

  // Builds a plugin from its two attributes.
  //   output_size  : requested pooled output size. NOTE(review): the expected
  //                  element count/order is not visible here -- confirm.
  //   pooling_type : pooling mode name. NOTE(review): the accepted values are
  //                  not visible here -- confirm against the definition.
  AdaptivePool2d(std::vector<int32_t> output_size, std::string pooling_type);

  // Builds a plugin from a raw byte buffer of `length` bytes.
  // NOTE(review): presumably the inverse of serialize() -- confirm.
  AdaptivePool2d(const void* buffer, size_t length);

  ~AdaptivePool2d() override = default;

  // ---- Output description ---------------------------------------------------

  int getNbOutputs() const noexcept override;

  nvinfer1::DimsExprs getOutputDimensions(
      int outputIndex,
      const nvinfer1::DimsExprs* inputs,
      int nbInputs,
      nvinfer1::IExprBuilder& exprBuilder) noexcept override;

  nvinfer1::DataType getOutputDataType(
      int index,
      const nvinfer1::DataType* inputType,
      int nbInputs) const noexcept override;

  bool supportsFormatCombination(
      int pos,
      const nvinfer1::PluginTensorDesc* inOut,
      int nbInputs,
      int nbOutputs) noexcept override;

  // ---- Lifecycle and execution ----------------------------------------------

  int initialize() noexcept override;

  void terminate() noexcept override;

  size_t getWorkspaceSize(
      const nvinfer1::PluginTensorDesc* inputs,
      int nbInputs,
      const nvinfer1::PluginTensorDesc* outputs,
      int nbOutputs) const noexcept override;

  // Takes input/output descriptors, input/output device pointers, a workspace
  // pointer and the CUDA stream.
  // NOTE(review): presumably launches the kernel declared in
  // adaptive_pool2d_kernel.h on `stream` -- confirm in the definition.
  int enqueue(
      const nvinfer1::PluginTensorDesc* inputDesc,
      const nvinfer1::PluginTensorDesc* outputDesc,
      const void* const* inputs,
      void* const* outputs,
      void* workspace,
      cudaStream_t stream) noexcept override;

  // ---- Serialization ----------------------------------------------------------

  size_t getSerializationSize() const noexcept override;

  void serialize(void* buffer) const noexcept override;

  // ---- Identity, configuration and ownership ------------------------------------

  const char* getPluginType() const noexcept override;

  const char* getPluginVersion() const noexcept override;

  void configurePlugin(
      const nvinfer1::DynamicPluginTensorDesc* in,
      int nbInputs,
      const nvinfer1::DynamicPluginTensorDesc* out,
      int nbOutputs) noexcept override;

  void destroy() noexcept override;

  nvinfer1::IPluginV2DynamicExt* clone() const noexcept override;

 private:
  // Plugin attributes; types mirror the first constructor's parameters.
  std::vector<int32_t> output_size_;
  std::string pooling_type_;
};
// Creator (factory) declaration for the AdaptivePool2d plugin.
//
// Only the interface is declared in this class body; apart from the defaulted
// destructor, no member function is defined here.
// NOTE(review): BaseCreator is declared outside this block (presumably in
// "common.h") and presumably derives from nvinfer1::IPluginCreator -- confirm.
class AdaptivePool2dPluginCreator : public BaseCreator {
 public:
  AdaptivePool2dPluginCreator();

  ~AdaptivePool2dPluginCreator() override = default;

  // ---- Identity ---------------------------------------------------------------

  const char* getPluginName() const noexcept override;

  const char* getPluginVersion() const noexcept override;

  // ---- Plugin attributes --------------------------------------------------------

  const nvinfer1::PluginFieldCollection* getFieldNames() noexcept override;

  // ---- Plugin construction ------------------------------------------------------

  // Creates a plugin instance from a field collection.
  // NOTE(review): presumably reads the output size and pooling type from `fc`
  // -- confirm in the definition.
  nvinfer1::IPluginV2DynamicExt* createPlugin(
      const char* name,
      const nvinfer1::PluginFieldCollection* fc) noexcept override;

  // Creates a plugin instance from `serialLength` bytes at `serialData`.
  nvinfer1::IPluginV2DynamicExt* deserializePlugin(
      const char* name,
      const void* serialData,
      size_t serialLength) noexcept override;

 private:
  // Static members are only declared here; they need an out-of-class
  // definition in a source file.
  static nvinfer1::PluginFieldCollection mFC;
  static std::vector<nvinfer1::PluginField> mPluginAttributes;

  // Per-creator copies of the plugin attributes.
  // NOTE(review): presumably filled in by createPlugin() -- confirm.
  std::vector<int32_t> output_size_;
  std::string pooling_type_;
};
// Passes AdaptivePool2dPluginCreator to the REGISTER_TENSORRT_PLUGIN macro,
// which is defined outside this file (TensorRT documents it as registering a
// plugin creator with the plugin registry).
// NOTE(review): this invocation lives in a header, so it is expanded in every
// translation unit that includes the header -- confirm that repeated
// registration is intended and harmless.
REGISTER_TENSORRT_PLUGIN(AdaptivePool2dPluginCreator);
} // namespace fastdeploy