diff --git a/cmake/paddle_inference.cmake b/cmake/paddle_inference.cmake
index 43c726a23..842d6298a 100644
--- a/cmake/paddle_inference.cmake
+++ b/cmake/paddle_inference.cmake
@@ -69,7 +69,7 @@ else()
   else()
     set(PADDLEINFERENCE_FILE "paddle_inference-linux-x64-${PADDLEINFERENCE_VERSION}.tgz")
     if(WITH_GPU)
-      set(PADDLEINFERENCE_FILE "paddle_inference-linux-x64-gpu-${PADDLEINFERENCE_VERSION}.tgz")
+      set(PADDLEINFERENCE_FILE "paddle_inference-linux-x64-gpu-trt-${PADDLEINFERENCE_VERSION}.tgz")
     endif()
   endif()
 endif()
diff --git a/fastdeploy/backends/paddle/paddle_backend.cc b/fastdeploy/backends/paddle/paddle_backend.cc
index b1d9282e1..c8e3ff44f 100644
--- a/fastdeploy/backends/paddle/paddle_backend.cc
+++ b/fastdeploy/backends/paddle/paddle_backend.cc
@@ -16,23 +16,43 @@
 
 namespace fastdeploy {
 
-void PaddleBackend::BuildOption(const PaddleBackendOption& option,
-                                const std::string& model_file) {
+void PaddleBackend::BuildOption(const PaddleBackendOption& option) {
   if (option.use_gpu) {
     config_.EnableUseGpu(option.gpu_mem_init_size, option.gpu_id);
+    if (option.enable_trt) {
+#ifdef ENABLE_TRT_BACKEND
+      auto precision = paddle_infer::PrecisionType::kFloat32;
+      if (option.trt_option.enable_fp16) {
+        precision = paddle_infer::PrecisionType::kHalf;
+      }
+      bool use_static = false;
+      if (option.trt_option.serialize_file != "") {
+        FDWARNING << "Detected that the TensorRT cache file has been set to " << option.trt_option.serialize_file << ", but note that while Paddle2TRT is enabled, the cache file will be saved to the directory where the Paddle model is located." << std::endl;
+        use_static = true;
+      }
+      config_.EnableTensorRtEngine(option.trt_option.max_workspace_size, 32, 3, precision, use_static);
+      std::map<std::string, std::vector<int>> max_shape;
+      std::map<std::string, std::vector<int>> min_shape;
+      std::map<std::string, std::vector<int>> opt_shape;
+      for (const auto& item : option.trt_option.min_shape) {
+        auto max_iter = option.trt_option.max_shape.find(item.first);
+        auto opt_iter = option.trt_option.opt_shape.find(item.first);
+        FDASSERT(max_iter != option.trt_option.max_shape.end(), "Cannot find %s in TrtBackendOption::max_shape.", item.first.c_str());
+        FDASSERT(opt_iter != option.trt_option.opt_shape.end(), "Cannot find %s in TrtBackendOption::opt_shape.", item.first.c_str());
+        max_shape[item.first].assign(max_iter->second.begin(), max_iter->second.end());
+        opt_shape[item.first].assign(opt_iter->second.begin(), opt_iter->second.end());
+        min_shape[item.first].assign(item.second.begin(), item.second.end());
+      }
+      if (min_shape.size() > 0) {
+        config_.SetTRTDynamicShapeInfo(min_shape, max_shape, opt_shape);
+      }
+#else
+      FDWARNING << "FastDeploy is not compiled with the TensorRT backend, so it will fall back to GPU inference with the Paddle Inference backend." << std::endl;
+#endif
+    }
   } else {
     config_.DisableGpu();
     if (option.enable_mkldnn) {
-      config_.EnableMKLDNN();
-      std::string contents;
-      if (!ReadBinaryFromFile(model_file, &contents)) {
-        return;
-      }
-      auto reader =
-          paddle2onnx::PaddleReader(contents.c_str(), contents.size());
-      if (reader.is_quantize_model) {
-        config_.EnableMkldnnInt8();
-      }
       config_.SetMkldnnCacheCapacity(option.mkldnn_cache_size);
     }
   }
@@ -62,28 +82,48 @@ bool PaddleBackend::InitFromPaddle(const std::string& model_file,
     return false;
   }
   config_.SetModel(model_file, params_file);
-  BuildOption(option, model_file);
+  BuildOption(option);
+
+  // The input/output information obtained from the predictor is not reliable, so use PaddleReader instead for now
+  std::string contents;
+  if (!ReadBinaryFromFile(model_file, &contents)) {
+    return false;
+  }
+  auto reader =
+      paddle2onnx::PaddleReader(contents.c_str(), contents.size());
+
+  // If it is a quantized model and runs on CPU with MKLDNN, automatically switch to int8 mode
+  if (reader.is_quantize_model) {
+    if (option.use_gpu) {
+      FDWARNING << "The loaded model is a quantized model; while running inference on GPU, please use the TensorRT backend to get better performance." << std::endl;
+    }
+    if (option.enable_mkldnn) {
+      config_.EnableMkldnnInt8();
+    } else {
+      FDWARNING << "The loaded model is a quantized model; while running inference on CPU, please enable MKLDNN to get better performance." << std::endl;
+    }
+  }
+
+  inputs_desc_.resize(reader.num_inputs);
+  for (int i = 0; i < reader.num_inputs; ++i) {
+    std::string name(reader.inputs[i].name);
+    std::vector<int64_t> shape(
+        reader.inputs[i].shape,
+        reader.inputs[i].shape + reader.inputs[i].rank);
+    inputs_desc_[i].name = name;
+    inputs_desc_[i].shape.assign(shape.begin(), shape.end());
+    inputs_desc_[i].dtype = ReaderDataTypeToFD(reader.inputs[i].dtype);
+  }
+  outputs_desc_.resize(reader.num_outputs);
+  for (int i = 0; i < reader.num_outputs; ++i) {
+    std::string name(reader.outputs[i].name);
+    std::vector<int64_t> shape(reader.outputs[i].shape, reader.outputs[i].shape + reader.outputs[i].rank);
+    outputs_desc_[i].name = name;
+    outputs_desc_[i].shape.assign(shape.begin(), shape.end());
+    outputs_desc_[i].dtype = ReaderDataTypeToFD(reader.outputs[i].dtype);
+  }
+
   predictor_ = paddle_infer::CreatePredictor(config_);
-  std::vector<std::string> input_names = predictor_->GetInputNames();
-  std::vector<std::string> output_names = predictor_->GetOutputNames();
-  for (size_t i = 0; i < input_names.size(); ++i) {
-    auto handle = predictor_->GetInputHandle(input_names[i]);
-    TensorInfo info;
-    auto shape = handle->shape();
-    info.shape.assign(shape.begin(), shape.end());
-    info.dtype = PaddleDataTypeToFD(handle->type());
-    info.name = input_names[i];
-    inputs_desc_.emplace_back(info);
-  }
-  for (size_t i = 0; i < output_names.size(); ++i) {
-    auto handle = predictor_->GetOutputHandle(output_names[i]);
-    TensorInfo info;
-    auto shape = handle->shape();
-    info.shape.assign(shape.begin(), shape.end());
-    info.dtype = PaddleDataTypeToFD(handle->type());
-    info.name = output_names[i];
-    outputs_desc_.emplace_back(info);
-  }
   initialized_ = true;
   return true;
 }
@@ -131,4 +171,4 @@ bool PaddleBackend::Infer(std::vector<FDTensor>& inputs,
   return true;
 }
 
-}  // namespace fastdeploy
\ No newline at end of file
+}  // namespace fastdeploy
diff --git a/fastdeploy/backends/paddle/paddle_backend.h b/fastdeploy/backends/paddle/paddle_backend.h
index b14a0e27c..0d59a8a33 100755
--- a/fastdeploy/backends/paddle/paddle_backend.h
+++ b/fastdeploy/backends/paddle/paddle_backend.h
@@ -20,9 +20,15 @@
 #include <string>
 #include <vector>
 
 #include "fastdeploy/backends/backend.h"
"fastdeploy/backends/backend.h" +#ifdef ENABLE_PADDLE_FRONTEND #include "paddle2onnx/converter.h" +#endif #include "paddle_inference_api.h" // NOLINT +#ifdef ENABLE_TRT_BACKEND +#include "fastdeploy/backends/tensorrt/trt_backend.h" +#endif + namespace fastdeploy { struct PaddleBackendOption { @@ -35,6 +41,11 @@ struct PaddleBackendOption { bool enable_log_info = false; + bool enable_trt = false; +#ifdef ENABLE_TRT_BACKEND + TrtBackendOption trt_option; +#endif + int mkldnn_cache_size = 1; int cpu_thread_num = 8; // initialize memory size(MB) for GPU @@ -58,18 +69,21 @@ void CopyTensorToCpu(std::unique_ptr& tensor, // Convert data type from paddle inference to fastdeploy FDDataType PaddleDataTypeToFD(const paddle_infer::DataType& dtype); +// Convert data type from paddle2onnx::PaddleReader to fastdeploy +FDDataType ReaderDataTypeToFD(int32_t dtype); + class PaddleBackend : public BaseBackend { public: PaddleBackend() {} virtual ~PaddleBackend() = default; - void BuildOption(const PaddleBackendOption& option, - const std::string& model_file); + void BuildOption(const PaddleBackendOption& option); bool InitFromPaddle( const std::string& model_file, const std::string& params_file, const PaddleBackendOption& option = PaddleBackendOption()); - bool Infer(std::vector& inputs, std::vector* outputs) override; + bool Infer(std::vector& inputs, + std::vector* outputs) override; int NumInputs() const override { return inputs_desc_.size(); } diff --git a/fastdeploy/backends/paddle/util.cc b/fastdeploy/backends/paddle/util.cc index efd07b3c5..498561791 100644 --- a/fastdeploy/backends/paddle/util.cc +++ b/fastdeploy/backends/paddle/util.cc @@ -89,4 +89,26 @@ FDDataType PaddleDataTypeToFD(const paddle_infer::DataType& dtype) { return fd_dtype; } +FDDataType ReaderDataTypeToFD(int32_t dtype) { + auto fd_dtype = FDDataType::FP32; + if (dtype == 0) { + fd_dtype = FDDataType::FP32; + } else if (dtype == 1) { + fd_dtype = FDDataType::FP64; + } else if (dtype == 2) { + fd_dtype = FDDataType::UINT8; + } else if (dtype == 3) { + fd_dtype = FDDataType::INT8; + } else if (dtype == 4) { + fd_dtype = FDDataType::INT32; + } else if (dtype == 5) { + fd_dtype = FDDataType::INT64; + } else if (dtype == 6) { + fd_dtype = FDDataType::FP16; + } else { + FDASSERT(false, "Unexpected data type: %d while call ReaderDataTypeToFD in PaddleBackend.", dtype); + } + return fd_dtype; +} + } // namespace fastdeploy diff --git a/fastdeploy/pybind/runtime.cc b/fastdeploy/pybind/runtime.cc index 685b44222..092fc9ebb 100755 --- a/fastdeploy/pybind/runtime.cc +++ b/fastdeploy/pybind/runtime.cc @@ -39,6 +39,7 @@ void BindRuntime(pybind11::module& m) { .def("set_lite_power_mode", &RuntimeOption::SetLitePowerMode) .def("set_trt_input_shape", &RuntimeOption::SetTrtInputShape) .def("set_trt_max_workspace_size", &RuntimeOption::SetTrtMaxWorkspaceSize) + .def("enable_paddle_to_trt", &RuntimeOption::EnablePaddleToTrt) .def("enable_trt_fp16", &RuntimeOption::EnableTrtFP16) .def("disable_trt_fp16", &RuntimeOption::DisableTrtFP16) .def("set_trt_cache_file", &RuntimeOption::SetTrtCacheFile) diff --git a/fastdeploy/runtime.cc b/fastdeploy/runtime.cc index 44cd5e9cb..8df9d6548 100755 --- a/fastdeploy/runtime.cc +++ b/fastdeploy/runtime.cc @@ -258,6 +258,17 @@ void RuntimeOption::EnablePaddleLogInfo() { pd_enable_log_info = true; } void RuntimeOption::DisablePaddleLogInfo() { pd_enable_log_info = false; } +void RuntimeOption::EnablePaddleToTrt() { + FDASSERT(backend == Backend::TRT, "Should call UseTrtBackend() before call EnablePaddleToTrt()."); 
+#ifdef ENABLE_PADDLE_BACKEND
+  FDINFO << "While using TrtBackend with EnablePaddleToTrt, FastDeploy will switch to the Paddle Inference backend and use its integrated TensorRT instead." << std::endl;
+  backend = Backend::PDINFER;
+  pd_enable_trt = true;
+#else
+  FDASSERT(false, "EnablePaddleToTrt requires FastDeploy to be compiled with the Paddle Inference backend, please rebuild your FastDeploy.");
+#endif
+}
+
 void RuntimeOption::SetPaddleMKLDNNCacheSize(int size) {
   FDASSERT(size > 0, "Parameter size must greater than 0.");
   pd_mkldnn_cache_size = size;
@@ -406,6 +417,21 @@ void Runtime::CreatePaddleBackend() {
   pd_option.gpu_id = option.device_id;
   pd_option.delete_pass_names = option.pd_delete_pass_names;
   pd_option.cpu_thread_num = option.cpu_thread_num;
+#ifdef ENABLE_TRT_BACKEND
+  if (pd_option.use_gpu && option.pd_enable_trt) {
+    pd_option.enable_trt = true;
+    auto trt_option = TrtBackendOption();
+    trt_option.gpu_id = option.device_id;
+    trt_option.enable_fp16 = option.trt_enable_fp16;
+    trt_option.max_batch_size = option.trt_max_batch_size;
+    trt_option.max_workspace_size = option.trt_max_workspace_size;
+    trt_option.max_shape = option.trt_max_shape;
+    trt_option.min_shape = option.trt_min_shape;
+    trt_option.opt_shape = option.trt_opt_shape;
+    trt_option.serialize_file = option.trt_serialize_file;
+    pd_option.trt_option = trt_option;
+  }
+#endif
   FDASSERT(option.model_format == ModelFormat::PADDLE,
            "PaddleBackend only support model format of ModelFormat::PADDLE.");
   backend_ = utils::make_unique<PaddleBackend>();
diff --git a/fastdeploy/runtime.h b/fastdeploy/runtime.h
index 26d565598..0cea2f026 100755
--- a/fastdeploy/runtime.h
+++ b/fastdeploy/runtime.h
@@ -126,6 +126,11 @@ struct FASTDEPLOY_DECL RuntimeOption {
   /// Set mkldnn switch while using Paddle Inference as inference backend
   void SetPaddleMKLDNN(bool pd_mkldnn = true);
 
+  /**
+   * @brief If the TensorRT backend is selected, EnablePaddleToTrt() will switch to the Paddle Inference backend and use its integrated TensorRT instead.
+   */
+  void EnablePaddleToTrt();
+
   /**
    * @brief Delete pass by name while using Paddle Inference as inference backend, this can be called multiple times to delete a set of passes
   */
@@ -214,6 +219,7 @@ struct FASTDEPLOY_DECL RuntimeOption {
   // ======Only for Paddle Backend=====
   bool pd_enable_mkldnn = true;
   bool pd_enable_log_info = false;
+  bool pd_enable_trt = false;
   int pd_mkldnn_cache_size = 1;
   std::vector<std::string> pd_delete_pass_names;
 
diff --git a/python/fastdeploy/runtime.py b/python/fastdeploy/runtime.py
index 901ef953c..aaba6abb3 100755
--- a/python/fastdeploy/runtime.py
+++ b/python/fastdeploy/runtime.py
@@ -217,6 +217,11 @@ class RuntimeOption:
         """
         return self._option.disable_trt_fp16()
 
+    def enable_paddle_to_trt(self):
+        """While using the TensorRT backend, enable_paddle_to_trt() will switch to the Paddle Inference backend and use its integrated TensorRT instead.
+        """
+        return self._option.enable_paddle_to_trt()
+
     def set_trt_max_workspace_size(self, trt_max_workspace_size):
         """Set max workspace size while using TensorRT backend.
         """
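For reference, a minimal Python sketch of how the option added by this patch would be used. The enable_paddle_to_trt(), set_trt_input_shape(), and set_trt_cache_file() bindings appear in the diff above; the model paths, the input name "x", and the set_model_path()/use_gpu()/use_trt_backend()/Runtime calls are assumptions based on the existing FastDeploy Python API, not part of this change.

    # Hypothetical usage sketch, not part of the patch; paths and input name are placeholders.
    import fastdeploy as fd

    option = fd.RuntimeOption()
    option.set_model_path("model.pdmodel", "model.pdiparams")  # Paddle-format model (assumed paths)
    option.use_gpu(0)                # Paddle-TRT only takes effect on GPU
    option.use_trt_backend()         # must be called before enable_paddle_to_trt()
    option.enable_paddle_to_trt()    # switch to Paddle Inference and its integrated TensorRT
    option.set_trt_input_shape("x", [1, 3, 224, 224])  # dynamic-shape hint for input "x" (assumed name)
    option.set_trt_cache_file("trt.cache")  # with Paddle2TRT the cache is saved next to the Paddle model

    runtime = fd.Runtime(option)     # inference then goes through the runtime as usual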