mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-08 18:11:00 +08:00
Remove tensorrt/common codes (#171)
This commit is contained in:
@@ -13,12 +13,17 @@
|
||||
// limitations under the License.
|
||||
|
||||
#include "fastdeploy/backends/tensorrt/trt_backend.h"
|
||||
#include <cstring>
|
||||
#include "NvInferSafeRuntime.h"
|
||||
#include "fastdeploy/utils/utils.h"
|
||||
#ifdef ENABLE_PADDLE_FRONTEND
|
||||
#include "paddle2onnx/converter.h"
|
||||
#endif
|
||||
|
||||
namespace fastdeploy {
|
||||
|
||||
FDTrtLogger* FDTrtLogger::logger = nullptr;
|
||||
|
||||
size_t TrtDataTypeSize(const nvinfer1::DataType& dtype) {
|
||||
if (dtype == nvinfer1::DataType::kFLOAT) {
|
||||
return sizeof(float);
|
||||
@@ -130,8 +135,8 @@ bool TrtBackend::InitFromTrt(const std::string& trt_engine_file,
|
||||
fin.seekg(0, std::ios::beg);
|
||||
fin.read(&(engine_buffer.at(0)), engine_buffer.size());
|
||||
fin.close();
|
||||
SampleUniquePtr<IRuntime> runtime{
|
||||
createInferRuntime(sample::gLogger.getTRTLogger())};
|
||||
FDUniquePtr<nvinfer1::IRuntime> runtime{
|
||||
nvinfer1::createInferRuntime(*FDTrtLogger::Get())};
|
||||
if (!runtime) {
|
||||
FDERROR << "Failed to call createInferRuntime()." << std::endl;
|
||||
return false;
|
||||
@@ -139,7 +144,7 @@ bool TrtBackend::InitFromTrt(const std::string& trt_engine_file,
|
||||
engine_ = std::shared_ptr<nvinfer1::ICudaEngine>(
|
||||
runtime->deserializeCudaEngine(engine_buffer.data(),
|
||||
engine_buffer.size()),
|
||||
samplesCommon::InferDeleter());
|
||||
FDInferDeleter());
|
||||
if (!engine_) {
|
||||
FDERROR << "Failed to call deserializeCudaEngine()." << std::endl;
|
||||
return false;
|
||||
@@ -320,10 +325,10 @@ void TrtBackend::GetInputOutputInfo() {
|
||||
auto dtype = engine_->getBindingDataType(i);
|
||||
if (engine_->bindingIsInput(i)) {
|
||||
inputs_desc_.emplace_back(TrtValueInfo{name, shape, dtype});
|
||||
inputs_buffer_[name] = DeviceBuffer(dtype);
|
||||
inputs_buffer_[name] = FDDeviceBuffer(dtype);
|
||||
} else {
|
||||
outputs_desc_.emplace_back(TrtValueInfo{name, shape, dtype});
|
||||
outputs_buffer_[name] = DeviceBuffer(dtype);
|
||||
outputs_buffer_[name] = FDDeviceBuffer(dtype);
|
||||
}
|
||||
}
|
||||
bindings_.resize(num_binds);
|
||||
@@ -334,7 +339,7 @@ void TrtBackend::AllocateBufferInDynamicShape(
|
||||
for (const auto& item : inputs) {
|
||||
auto idx = engine_->getBindingIndex(item.name.c_str());
|
||||
std::vector<int> shape(item.shape.begin(), item.shape.end());
|
||||
auto dims = sample::toDims(shape);
|
||||
auto dims = ToDims(shape);
|
||||
context_->setBindingDimensions(idx, dims);
|
||||
if (item.Nbytes() > inputs_buffer_[item.name].nbBytes()) {
|
||||
inputs_buffer_[item.name].resize(dims);
|
||||
@@ -357,7 +362,7 @@ void TrtBackend::AllocateBufferInDynamicShape(
|
||||
(*outputs)[ori_idx].shape.assign(output_dims.d,
|
||||
output_dims.d + output_dims.nbDims);
|
||||
(*outputs)[ori_idx].name = outputs_desc_[i].name;
|
||||
(*outputs)[ori_idx].data.resize(volume(output_dims) *
|
||||
(*outputs)[ori_idx].data.resize(Volume(output_dims) *
|
||||
TrtDataTypeSize(outputs_desc_[i].dtype));
|
||||
if ((*outputs)[ori_idx].Nbytes() >
|
||||
outputs_buffer_[outputs_desc_[i].name].nbBytes()) {
|
||||
@@ -373,19 +378,19 @@ bool TrtBackend::CreateTrtEngine(const std::string& onnx_model,
|
||||
1U << static_cast<uint32_t>(
|
||||
nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
|
||||
|
||||
builder_ = SampleUniquePtr<nvinfer1::IBuilder>(
|
||||
nvinfer1::createInferBuilder(sample::gLogger.getTRTLogger()));
|
||||
builder_ = FDUniquePtr<nvinfer1::IBuilder>(
|
||||
nvinfer1::createInferBuilder(*FDTrtLogger::Get()));
|
||||
if (!builder_) {
|
||||
FDERROR << "Failed to call createInferBuilder()." << std::endl;
|
||||
return false;
|
||||
}
|
||||
network_ = SampleUniquePtr<nvinfer1::INetworkDefinition>(
|
||||
network_ = FDUniquePtr<nvinfer1::INetworkDefinition>(
|
||||
builder_->createNetworkV2(explicitBatch));
|
||||
if (!network_) {
|
||||
FDERROR << "Failed to call createNetworkV2()." << std::endl;
|
||||
return false;
|
||||
}
|
||||
auto config = SampleUniquePtr<nvinfer1::IBuilderConfig>(
|
||||
auto config = FDUniquePtr<nvinfer1::IBuilderConfig>(
|
||||
builder_->createBuilderConfig());
|
||||
if (!config) {
|
||||
FDERROR << "Failed to call createBuilderConfig()." << std::endl;
|
||||
@@ -402,8 +407,8 @@ bool TrtBackend::CreateTrtEngine(const std::string& onnx_model,
|
||||
}
|
||||
}
|
||||
|
||||
parser_ = SampleUniquePtr<nvonnxparser::IParser>(
|
||||
nvonnxparser::createParser(*network_, sample::gLogger.getTRTLogger()));
|
||||
parser_ = FDUniquePtr<nvonnxparser::IParser>(
|
||||
nvonnxparser::createParser(*network_, *FDTrtLogger::Get()));
|
||||
if (!parser_) {
|
||||
FDERROR << "Failed to call createParser()." << std::endl;
|
||||
return false;
|
||||
@@ -429,7 +434,7 @@ bool TrtBackend::CreateTrtEngine(const std::string& onnx_model,
|
||||
// set min shape
|
||||
FDASSERT(profile->setDimensions(item.first.c_str(),
|
||||
nvinfer1::OptProfileSelector::kMIN,
|
||||
sample::toDims(item.second)),
|
||||
ToDims(item.second)),
|
||||
"[TrtBackend] Failed to set min_shape for input: %s in TrtBackend.", item.first.c_str());
|
||||
|
||||
// set optimization shape
|
||||
@@ -438,7 +443,7 @@ bool TrtBackend::CreateTrtEngine(const std::string& onnx_model,
|
||||
"[TrtBackend] Cannot find input name: %s in TrtBackendOption::opt_shape.", item.first.c_str());
|
||||
FDASSERT(profile->setDimensions(item.first.c_str(),
|
||||
nvinfer1::OptProfileSelector::kOPT,
|
||||
sample::toDims(iter->second)),
|
||||
ToDims(iter->second)),
|
||||
"[TrtBackend] Failed to set opt_shape for input: %s in TrtBackend.", item.first.c_str());
|
||||
// set max shape
|
||||
iter = option.max_shape.find(item.first);
|
||||
@@ -446,21 +451,21 @@ bool TrtBackend::CreateTrtEngine(const std::string& onnx_model,
|
||||
"[TrtBackend] Cannot find input name: %s in TrtBackendOption::max_shape.", item.first);
|
||||
FDASSERT(profile->setDimensions(item.first.c_str(),
|
||||
nvinfer1::OptProfileSelector::kMAX,
|
||||
sample::toDims(iter->second)),
|
||||
ToDims(iter->second)),
|
||||
"[TrtBackend] Failed to set max_shape for input: %s in TrtBackend.", item.first);
|
||||
}
|
||||
config->addOptimizationProfile(profile);
|
||||
}
|
||||
|
||||
SampleUniquePtr<IHostMemory> plan{
|
||||
FDUniquePtr<nvinfer1::IHostMemory> plan{
|
||||
builder_->buildSerializedNetwork(*network_, *config)};
|
||||
if (!plan) {
|
||||
FDERROR << "Failed to call buildSerializedNetwork()." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
SampleUniquePtr<IRuntime> runtime{
|
||||
createInferRuntime(sample::gLogger.getTRTLogger())};
|
||||
FDUniquePtr<nvinfer1::IRuntime> runtime{
|
||||
nvinfer1::createInferRuntime(*FDTrtLogger::Get())};
|
||||
if (!runtime) {
|
||||
FDERROR << "Failed to call createInferRuntime()." << std::endl;
|
||||
return false;
|
||||
@@ -468,7 +473,7 @@ bool TrtBackend::CreateTrtEngine(const std::string& onnx_model,
|
||||
|
||||
engine_ = std::shared_ptr<nvinfer1::ICudaEngine>(
|
||||
runtime->deserializeCudaEngine(plan->data(), plan->size()),
|
||||
samplesCommon::InferDeleter());
|
||||
FDInferDeleter());
|
||||
if (!engine_) {
|
||||
FDERROR << "Failed to call deserializeCudaEngine()." << std::endl;
|
||||
return false;
|
||||
|
Reference in New Issue
Block a user