mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-12 20:11:20 +08:00
[Other] Optimize code style (#1032)
* Optimize code * optimize code * optimize code * fix compile error
This commit is contained in:
@@ -134,9 +134,9 @@ bool TrtBackend::InitFromPaddle(const std::string& model_file,
|
||||
int calibration_cache_size = 0;
|
||||
if (!paddle2onnx::Export(model_file.c_str(), params_file.c_str(),
|
||||
&model_content_ptr, &model_content_size, 11, true,
|
||||
verbose, true, true, true, ops.data(),
|
||||
1, "tensorrt",
|
||||
&calibration_cache_ptr, &calibration_cache_size, "", &save_external_)) {
|
||||
verbose, true, true, true, ops.data(), 1, "tensorrt",
|
||||
&calibration_cache_ptr, &calibration_cache_size, "",
|
||||
&save_external_)) {
|
||||
FDERROR << "Error occured while export PaddlePaddle to ONNX format."
|
||||
<< std::endl;
|
||||
return false;
|
||||
@@ -152,11 +152,11 @@ bool TrtBackend::InitFromPaddle(const std::string& model_file,
|
||||
calibration_str_ = calibration_str;
|
||||
delete[] calibration_cache_ptr;
|
||||
}
|
||||
if(save_external_){
|
||||
if (save_external_) {
|
||||
model_file_name_ = "model.onnx";
|
||||
std::fstream f(model_file_name_, std::ios::out);
|
||||
FDASSERT(f.is_open(), "Can not open file: %s to save model.",
|
||||
model_file_name_.c_str());
|
||||
model_file_name_.c_str());
|
||||
f << onnx_model_proto;
|
||||
f.close();
|
||||
return InitFromOnnx(model_file_name_, option, false);
|
||||
@@ -215,13 +215,14 @@ bool TrtBackend::InitFromOnnx(const std::string& model_file,
|
||||
outputs_desc_.resize(onnx_reader.num_outputs);
|
||||
for (int i = 0; i < onnx_reader.num_inputs; ++i) {
|
||||
std::string name(onnx_reader.inputs[i].name);
|
||||
std::vector<int64_t> shape(
|
||||
onnx_reader.inputs[i].shape,
|
||||
onnx_reader.inputs[i].shape + onnx_reader.inputs[i].rank);
|
||||
std::vector<int64_t> shape(onnx_reader.inputs[i].shape,
|
||||
onnx_reader.inputs[i].shape +
|
||||
onnx_reader.inputs[i].rank);
|
||||
inputs_desc_[i].name = name;
|
||||
inputs_desc_[i].shape.assign(shape.begin(), shape.end());
|
||||
inputs_desc_[i].dtype = ReaderDtypeToTrtDtype(onnx_reader.inputs[i].dtype);
|
||||
inputs_desc_[i].original_dtype = ReaderDtypeToFDDtype(onnx_reader.inputs[i].dtype);
|
||||
inputs_desc_[i].original_dtype =
|
||||
ReaderDtypeToFDDtype(onnx_reader.inputs[i].dtype);
|
||||
auto info = ShapeRangeInfo(shape);
|
||||
info.name = name;
|
||||
auto iter_min = option.min_shape.find(name);
|
||||
@@ -237,9 +238,9 @@ bool TrtBackend::InitFromOnnx(const std::string& model_file,
|
||||
|
||||
for (int i = 0; i < onnx_reader.num_outputs; ++i) {
|
||||
std::string name(onnx_reader.outputs[i].name);
|
||||
std::vector<int64_t> shape(
|
||||
onnx_reader.outputs[i].shape,
|
||||
onnx_reader.outputs[i].shape + onnx_reader.outputs[i].rank);
|
||||
std::vector<int64_t> shape(onnx_reader.outputs[i].shape,
|
||||
onnx_reader.outputs[i].shape +
|
||||
onnx_reader.outputs[i].rank);
|
||||
outputs_desc_[i].name = name;
|
||||
outputs_desc_[i].shape.assign(shape.begin(), shape.end());
|
||||
outputs_desc_[i].dtype =
|
||||
@@ -252,10 +253,10 @@ bool TrtBackend::InitFromOnnx(const std::string& model_file,
|
||||
stream_ = reinterpret_cast<cudaStream_t>(option_.external_stream_);
|
||||
} else {
|
||||
FDASSERT(cudaStreamCreate(&stream_) == 0,
|
||||
"[ERROR] Error occurs while calling cudaStreamCreate().");
|
||||
"[ERROR] Error occurs while calling cudaStreamCreate().");
|
||||
}
|
||||
|
||||
if(save_external_){
|
||||
if (save_external_) {
|
||||
onnx_content.clear();
|
||||
onnx_content = model_file_name_;
|
||||
}
|
||||
@@ -283,8 +284,7 @@ int TrtBackend::ShapeRangeInfoUpdated(const std::vector<FDTensor>& inputs) {
|
||||
}
|
||||
|
||||
bool TrtBackend::Infer(std::vector<FDTensor>& inputs,
|
||||
std::vector<FDTensor>* outputs,
|
||||
bool copy_to_fd) {
|
||||
std::vector<FDTensor>* outputs, bool copy_to_fd) {
|
||||
if (inputs.size() != NumInputs()) {
|
||||
FDERROR << "Require " << NumInputs() << "inputs, but get " << inputs.size()
|
||||
<< "." << std::endl;
|
||||
@@ -297,7 +297,8 @@ bool TrtBackend::Infer(std::vector<FDTensor>& inputs,
|
||||
<< "TensorRT engine will be rebuilt once shape range information "
|
||||
"changed, this may take lots of time, you can set a proper shape "
|
||||
"range before loading model to avoid rebuilding process. refer "
|
||||
"https://github.com/PaddlePaddle/FastDeploy/blob/develop/docs/en/faq/"
|
||||
"https://github.com/PaddlePaddle/FastDeploy/blob/develop/docs/en/"
|
||||
"faq/"
|
||||
"tensorrt_tricks.md for more details."
|
||||
<< std::endl;
|
||||
BuildTrtEngine();
|
||||
@@ -314,38 +315,42 @@ bool TrtBackend::Infer(std::vector<FDTensor>& inputs,
|
||||
for (size_t i = 0; i < outputs->size(); ++i) {
|
||||
// if the final output tensor's dtype is different from the model output tensor's dtype,
|
||||
// then we need cast the data to the final output's dtype
|
||||
auto model_output_dtype = GetFDDataType(outputs_device_buffer_[(*outputs)[i].name].dtype());
|
||||
auto model_output_dtype =
|
||||
GetFDDataType(outputs_device_buffer_[(*outputs)[i].name].dtype());
|
||||
if ((*outputs)[i].dtype != model_output_dtype) {
|
||||
FDTensor output_tensor;
|
||||
output_tensor.SetExternalData((*outputs)[i].shape, model_output_dtype,
|
||||
outputs_device_buffer_[(*outputs)[i].name].data(),
|
||||
Device::GPU);
|
||||
|
||||
casted_output_tensors_[(*outputs)[i].name].Resize((*outputs)[i].shape, (*outputs)[i].dtype,
|
||||
(*outputs)[i].name, Device::GPU);
|
||||
function::CudaCast(output_tensor, &casted_output_tensors_[(*outputs)[i].name], stream_);
|
||||
if(!copy_to_fd) {
|
||||
(*outputs)[i].SetExternalData((*outputs)[i].shape, model_output_dtype,
|
||||
casted_output_tensors_[(*outputs)[i].name].MutableData(),
|
||||
Device::GPU, option_.gpu_id);
|
||||
output_tensor.SetExternalData(
|
||||
(*outputs)[i].shape, model_output_dtype,
|
||||
outputs_device_buffer_[(*outputs)[i].name].data(), Device::GPU);
|
||||
|
||||
casted_output_tensors_[(*outputs)[i].name].Resize(
|
||||
(*outputs)[i].shape, (*outputs)[i].dtype, (*outputs)[i].name,
|
||||
Device::GPU);
|
||||
function::CudaCast(output_tensor,
|
||||
&casted_output_tensors_[(*outputs)[i].name], stream_);
|
||||
if (!copy_to_fd) {
|
||||
(*outputs)[i].SetExternalData(
|
||||
(*outputs)[i].shape, model_output_dtype,
|
||||
casted_output_tensors_[(*outputs)[i].name].MutableData(),
|
||||
Device::GPU, option_.gpu_id);
|
||||
}
|
||||
} else {
|
||||
casted_output_tensors_[(*outputs)[i].name].SetExternalData(
|
||||
(*outputs)[i].shape, model_output_dtype,
|
||||
outputs_device_buffer_[(*outputs)[i].name].data(),
|
||||
Device::GPU);
|
||||
outputs_device_buffer_[(*outputs)[i].name].data(), Device::GPU);
|
||||
}
|
||||
}
|
||||
if (copy_to_fd) {
|
||||
for (size_t i = 0; i < outputs->size(); ++i) {
|
||||
FDASSERT(cudaMemcpyAsync((*outputs)[i].Data(),
|
||||
casted_output_tensors_[(*outputs)[i].name].Data(),
|
||||
(*outputs)[i].Nbytes(), cudaMemcpyDeviceToHost,
|
||||
stream_) == 0,
|
||||
"[ERROR] Error occurs while copy memory from GPU to CPU.");
|
||||
FDASSERT(
|
||||
cudaMemcpyAsync((*outputs)[i].Data(),
|
||||
casted_output_tensors_[(*outputs)[i].name].Data(),
|
||||
(*outputs)[i].Nbytes(), cudaMemcpyDeviceToHost,
|
||||
stream_) == 0,
|
||||
"[ERROR] Error occurs while copy memory from GPU to CPU.");
|
||||
}
|
||||
FDASSERT(cudaStreamSynchronize(stream_) == cudaSuccess,
|
||||
"[ERROR] Error occurs while sync cuda stream.");
|
||||
"[ERROR] Error occurs while sync cuda stream.");
|
||||
}
|
||||
|
||||
return true;
|
||||
@@ -356,10 +361,12 @@ void TrtBackend::GetInputOutputInfo() {
|
||||
std::unordered_map<std::string, FDDataType> inputs_original_dtype_map;
|
||||
std::unordered_map<std::string, FDDataType> outputs_original_dtype_map;
|
||||
for (size_t i = 0; i < inputs_desc_.size(); ++i) {
|
||||
inputs_original_dtype_map[inputs_desc_[i].name] = inputs_desc_[i].original_dtype;
|
||||
inputs_original_dtype_map[inputs_desc_[i].name] =
|
||||
inputs_desc_[i].original_dtype;
|
||||
}
|
||||
for (size_t i = 0; i < outputs_desc_.size(); ++i) {
|
||||
outputs_original_dtype_map[outputs_desc_[i].name] = outputs_desc_[i].original_dtype;
|
||||
outputs_original_dtype_map[outputs_desc_[i].name] =
|
||||
outputs_desc_[i].original_dtype;
|
||||
}
|
||||
|
||||
// Re-read the tensor infos from TRT model and write into inputs_desc_ and outputs_desc_
|
||||
@@ -373,12 +380,18 @@ void TrtBackend::GetInputOutputInfo() {
|
||||
auto shape = ToVec(engine_->getBindingDimensions(i));
|
||||
auto dtype = engine_->getBindingDataType(i);
|
||||
if (engine_->bindingIsInput(i)) {
|
||||
auto original_dtype = inputs_original_dtype_map.count(name) ? inputs_original_dtype_map[name] : GetFDDataType(dtype);
|
||||
inputs_desc_.emplace_back(TrtValueInfo{name, shape, dtype, original_dtype});
|
||||
auto original_dtype = inputs_original_dtype_map.count(name)
|
||||
? inputs_original_dtype_map[name]
|
||||
: GetFDDataType(dtype);
|
||||
inputs_desc_.emplace_back(
|
||||
TrtValueInfo{name, shape, dtype, original_dtype});
|
||||
inputs_device_buffer_[name] = FDDeviceBuffer(dtype);
|
||||
} else {
|
||||
auto original_dtype = outputs_original_dtype_map.count(name) ? outputs_original_dtype_map[name] : GetFDDataType(dtype);
|
||||
outputs_desc_.emplace_back(TrtValueInfo{name, shape, dtype, original_dtype});
|
||||
auto original_dtype = outputs_original_dtype_map.count(name)
|
||||
? outputs_original_dtype_map[name]
|
||||
: GetFDDataType(dtype);
|
||||
outputs_desc_.emplace_back(
|
||||
TrtValueInfo{name, shape, dtype, original_dtype});
|
||||
outputs_device_buffer_[name] = FDDeviceBuffer(dtype);
|
||||
casted_output_tensors_[name] = FDTensor();
|
||||
}
|
||||
@@ -391,8 +404,9 @@ void TrtBackend::SetInputs(const std::vector<FDTensor>& inputs) {
|
||||
for (const auto& item : inputs) {
|
||||
// auto idx = engine_->getBindingIndex(item.name.c_str());
|
||||
auto iter = io_name_index_.find(item.name);
|
||||
FDASSERT(iter != io_name_index_.end(), "TRTBackend SetInputs not find name:%s", item.name.c_str());
|
||||
auto idx = iter->second;
|
||||
FDASSERT(iter != io_name_index_.end(),
|
||||
"TRTBackend SetInputs not find name:%s", item.name.c_str());
|
||||
auto idx = iter->second;
|
||||
std::vector<int> shape(item.shape.begin(), item.shape.end());
|
||||
auto dims = ToDims(shape);
|
||||
context_->setBindingDimensions(idx, dims);
|
||||
@@ -424,9 +438,8 @@ void TrtBackend::SetInputs(const std::vector<FDTensor>& inputs) {
|
||||
"Error occurs while copy memory from CPU to GPU.");
|
||||
} else {
|
||||
FDASSERT(cudaMemcpyAsync(inputs_device_buffer_[item.name].data(),
|
||||
item.Data(),
|
||||
item.Nbytes(), cudaMemcpyHostToDevice,
|
||||
stream_) == 0,
|
||||
item.Data(), item.Nbytes(),
|
||||
cudaMemcpyHostToDevice, stream_) == 0,
|
||||
"Error occurs while copy memory from CPU to GPU.");
|
||||
}
|
||||
}
|
||||
@@ -443,8 +456,10 @@ void TrtBackend::AllocateOutputsBuffer(std::vector<FDTensor>* outputs,
|
||||
for (size_t i = 0; i < outputs_desc_.size(); ++i) {
|
||||
// auto idx = engine_->getBindingIndex(outputs_desc_[i].name.c_str());
|
||||
auto idx_iter = io_name_index_.find(outputs_desc_[i].name);
|
||||
FDASSERT(idx_iter != io_name_index_.end(), "TRTBackend Outputs not find name:%s", outputs_desc_[i].name.c_str());
|
||||
auto idx = idx_iter->second;
|
||||
FDASSERT(idx_iter != io_name_index_.end(),
|
||||
"TRTBackend Outputs not find name:%s",
|
||||
outputs_desc_[i].name.c_str());
|
||||
auto idx = idx_iter->second;
|
||||
auto output_dims = context_->getBindingDimensions(idx);
|
||||
|
||||
// find the original index of output
|
||||
@@ -457,23 +472,22 @@ void TrtBackend::AllocateOutputsBuffer(std::vector<FDTensor>* outputs,
|
||||
|
||||
// Allocate output buffer memory
|
||||
outputs_device_buffer_[outputs_desc_[i].name].resize(output_dims);
|
||||
|
||||
|
||||
// binding output buffer
|
||||
bindings_[idx] = outputs_device_buffer_[outputs_desc_[i].name].data();
|
||||
|
||||
bindings_[idx] = outputs_device_buffer_[outputs_desc_[i].name].data();
|
||||
|
||||
// set user's outputs info
|
||||
std::vector<int64_t> shape(output_dims.d,
|
||||
output_dims.d + output_dims.nbDims);
|
||||
if(copy_to_fd) {
|
||||
if (copy_to_fd) {
|
||||
(*outputs)[ori_idx].is_pinned_memory = option_.enable_pinned_memory;
|
||||
(*outputs)[ori_idx].Resize(shape, outputs_desc_[i].original_dtype,
|
||||
outputs_desc_[i].name);
|
||||
} else {
|
||||
(*outputs)[ori_idx].name = outputs_desc_[i].name;
|
||||
(*outputs)[ori_idx].SetExternalData(
|
||||
shape, outputs_desc_[i].original_dtype,
|
||||
bindings_[idx], Device::GPU,
|
||||
option_.gpu_id);
|
||||
shape, outputs_desc_[i].original_dtype, bindings_[idx], Device::GPU,
|
||||
option_.gpu_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -587,7 +601,8 @@ bool TrtBackend::BuildTrtEngine() {
|
||||
if (option_.serialize_file != "") {
|
||||
FDINFO << "Serialize TensorRTEngine to local file "
|
||||
<< option_.serialize_file << "." << std::endl;
|
||||
std::ofstream engine_file(option_.serialize_file.c_str(), std::ios::binary | std::ios::out);
|
||||
std::ofstream engine_file(option_.serialize_file.c_str(),
|
||||
std::ios::binary | std::ios::out);
|
||||
if (!engine_file) {
|
||||
FDERROR << "Failed to open " << option_.serialize_file << " to write."
|
||||
<< std::endl;
|
||||
@@ -628,10 +643,11 @@ bool TrtBackend::CreateTrtEngineFromOnnx(const std::string& onnx_model_buffer) {
|
||||
return false;
|
||||
}
|
||||
bool model_parser;
|
||||
if(save_external_){
|
||||
model_parser=!parser_->parseFromFile(onnx_model_buffer.c_str(), 0);
|
||||
}else{
|
||||
model_parser = !parser_->parse(onnx_model_buffer.data(), onnx_model_buffer.size());
|
||||
if (save_external_) {
|
||||
model_parser = !parser_->parseFromFile(onnx_model_buffer.c_str(), 0);
|
||||
} else {
|
||||
model_parser =
|
||||
!parser_->parse(onnx_model_buffer.data(), onnx_model_buffer.size());
|
||||
}
|
||||
if (model_parser) {
|
||||
FDERROR << "Failed to parse ONNX model by TensorRT." << std::endl;
|
||||
@@ -665,7 +681,8 @@ bool TrtBackend::CreateTrtEngineFromOnnx(const std::string& onnx_model_buffer) {
|
||||
"should be noticed that FastDeploy will rebuild the engine while "
|
||||
"new input shape is out of the collected shape range, this may "
|
||||
"bring some time consuming problem, refer "
|
||||
"https://github.com/PaddlePaddle/FastDeploy/blob/develop/docs/en/faq/"
|
||||
"https://github.com/PaddlePaddle/FastDeploy/blob/develop/docs/en/"
|
||||
"faq/"
|
||||
"tensorrt_tricks.md for more details."
|
||||
<< std::endl;
|
||||
initialized_ = true;
|
||||
@@ -721,27 +738,24 @@ std::vector<TensorInfo> TrtBackend::GetOutputInfos() {
|
||||
return infos;
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseBackend> TrtBackend::Clone(void *stream, int device_id) {
|
||||
std::unique_ptr<BaseBackend> TrtBackend::Clone(void* stream, int device_id) {
|
||||
std::unique_ptr<BaseBackend> new_backend = utils::make_unique<TrtBackend>();
|
||||
auto casted_backend = dynamic_cast<TrtBackend*>(new_backend.get());
|
||||
if(device_id > 0 && device_id != option_.gpu_id) {
|
||||
if (device_id > 0 && device_id != option_.gpu_id) {
|
||||
auto clone_option = option_;
|
||||
clone_option.gpu_id = device_id;
|
||||
clone_option.external_stream_ = stream;
|
||||
if (option_.model_format == ModelFormat::ONNX) {
|
||||
FDASSERT(casted_backend->InitFromOnnx(option_.model_file, clone_option),
|
||||
"Clone model from ONNX failed while initialize TrtBackend.");
|
||||
"Clone model from ONNX failed while initialize TrtBackend.");
|
||||
} else {
|
||||
FDASSERT(casted_backend->InitFromPaddle(option_.model_file,
|
||||
option_.params_file, clone_option),
|
||||
"Clone model from Paddle failed while initialize TrtBackend.");
|
||||
FDASSERT(casted_backend->InitFromPaddle(
|
||||
option_.model_file, option_.params_file, clone_option),
|
||||
"Clone model from Paddle failed while initialize TrtBackend.");
|
||||
}
|
||||
FDWARNING << "The target device id:"
|
||||
<< device_id
|
||||
<< " is different from current device id:"
|
||||
<< option_.gpu_id
|
||||
<< ", cannot share memory with current engine."
|
||||
<< std::endl;
|
||||
FDWARNING << "The target device id:" << device_id
|
||||
<< " is different from current device id:" << option_.gpu_id
|
||||
<< ", cannot share memory with current engine." << std::endl;
|
||||
return new_backend;
|
||||
}
|
||||
cudaSetDevice(option_.gpu_id);
|
||||
@@ -750,12 +764,15 @@ std::unique_ptr<BaseBackend> TrtBackend::Clone(void *stream, int device_id) {
|
||||
casted_backend->stream_ = reinterpret_cast<cudaStream_t>(stream);
|
||||
} else {
|
||||
FDASSERT(cudaStreamCreate(&casted_backend->stream_) == 0,
|
||||
"[ERROR] Error occurs while clone calling cudaStreamCreate().");
|
||||
"[ERROR] Error occurs while clone calling cudaStreamCreate().");
|
||||
}
|
||||
casted_backend->inputs_desc_.assign(inputs_desc_.begin(), inputs_desc_.end());
|
||||
casted_backend->outputs_desc_.assign(outputs_desc_.begin(), outputs_desc_.end());
|
||||
casted_backend->outputs_order_.insert(outputs_order_.begin(), outputs_order_.end());
|
||||
casted_backend->shape_range_info_.insert(shape_range_info_.begin(), shape_range_info_.end());
|
||||
casted_backend->outputs_desc_.assign(outputs_desc_.begin(),
|
||||
outputs_desc_.end());
|
||||
casted_backend->outputs_order_.insert(outputs_order_.begin(),
|
||||
outputs_order_.end());
|
||||
casted_backend->shape_range_info_.insert(shape_range_info_.begin(),
|
||||
shape_range_info_.end());
|
||||
casted_backend->engine_ = engine_;
|
||||
casted_backend->context_ = std::shared_ptr<nvinfer1::IExecutionContext>(
|
||||
casted_backend->engine_->createExecutionContext());
|
||||
|
Reference in New Issue
Block a user