diff --git a/cmake/paddle2onnx.cmake b/cmake/paddle2onnx.cmake index aec75dfdc..14b56e9f3 100755 --- a/cmake/paddle2onnx.cmake +++ b/cmake/paddle2onnx.cmake @@ -27,7 +27,7 @@ set(PADDLE2ONNX_LIB_DIR set(CMAKE_BUILD_RPATH "${CMAKE_BUILD_RPATH}" "${PADDLE2ONNX_LIB_DIR}") -include_directories(${PADDLE2ONNX_INC_DIR}) +include_directories(BEFORE ${PADDLE2ONNX_INC_DIR}) if(WIN32) set(PADDLE2ONNX_COMPILE_LIB "${PADDLE2ONNX_INSTALL_DIR}/lib/paddle2onnx.lib" @@ -45,7 +45,7 @@ endif(WIN32) if (NOT PADDLE2ONNX_URL) # Use default paddle2onnx url if custom url is not setting set(PADDLE2ONNX_URL_BASE "https://bj.bcebos.com/fastdeploy/third_libs/") - set(PADDLE2ONNX_VERSION "1.0.7") + set(PADDLE2ONNX_VERSION "1.0.8rc") if(WIN32) set(PADDLE2ONNX_FILE "paddle2onnx-win-x64-${PADDLE2ONNX_VERSION}.zip") if(NOT CMAKE_CL_64) diff --git a/fastdeploy/runtime/backends/ort/option.h b/fastdeploy/runtime/backends/ort/option.h index 155bcf908..4bc0874c7 100755 --- a/fastdeploy/runtime/backends/ort/option.h +++ b/fastdeploy/runtime/backends/ort/option.h @@ -48,5 +48,10 @@ struct OrtBackendOption { void* external_stream_ = nullptr; /// Use fp16 to infer bool enable_fp16 = false; + + std::vector<std::string> ort_disabled_ops_{}; + void DisableOrtFP16OpTypes(const std::vector<std::string>& ops) { + ort_disabled_ops_.insert(ort_disabled_ops_.end(), ops.begin(), ops.end()); + } }; } // namespace fastdeploy diff --git a/fastdeploy/runtime/backends/ort/option_pybind.cc b/fastdeploy/runtime/backends/ort/option_pybind.cc index 15ef2eeb0..6eb4dbd14 100644 --- a/fastdeploy/runtime/backends/ort/option_pybind.cc +++ b/fastdeploy/runtime/backends/ort/option_pybind.cc @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "fastdeploy/runtime/backends/ort/option.h" #include "fastdeploy/pybind/main.h" +#include "fastdeploy/runtime/backends/ort/option.h" namespace fastdeploy { @@ -29,7 +29,9 @@ void BindOrtOption(pybind11::module& m) { .def_readwrite("execution_mode", &OrtBackendOption::execution_mode) .def_readwrite("device", &OrtBackendOption::device) .def_readwrite("device_id", &OrtBackendOption::device_id) - .def_readwrite("enable_fp16", &OrtBackendOption::enable_fp16); + .def_readwrite("enable_fp16", &OrtBackendOption::enable_fp16) + .def("disable_ort_fp16_op_types", + &OrtBackendOption::DisableOrtFP16OpTypes); } } // namespace fastdeploy diff --git a/fastdeploy/runtime/backends/ort/ort_backend.cc b/fastdeploy/runtime/backends/ort/ort_backend.cc index ba874409f..987bf1882 100644 --- a/fastdeploy/runtime/backends/ort/ort_backend.cc +++ b/fastdeploy/runtime/backends/ort/ort_backend.cc @@ -180,16 +180,26 @@ bool OrtBackend::InitFromPaddle(const std::string& model_buffer, strcpy(ops[0].export_op_name, "MultiClassNMS"); strcpy(ops[1].op_name, "pool2d"); strcpy(ops[1].export_op_name, "AdaptivePool2d"); + converted_to_fp16 = option.enable_fp16; + std::vector<const char*> disable_fp16_ops; + for (auto i = 0; i < option.ort_disabled_ops_.size(); i++) { + auto one_type = option.ort_disabled_ops_[i]; + char* charStr = new char[one_type.size() + 1]; + std::strcpy(charStr, one_type.c_str()); + disable_fp16_ops.push_back(charStr); + } if (!paddle2onnx::Export( model_buffer.c_str(), model_buffer.size(), params_buffer.c_str(), params_buffer.size(), &model_content_ptr, &model_content_size, 11, true, verbose, true, true, true, ops.data(), 2, "onnxruntime", - nullptr, 0, "", &save_external, false)) { + nullptr, 0, "", &save_external, option.enable_fp16, + disable_fp16_ops.data(), option.ort_disabled_ops_.size())) { FDERROR << "Error occured while export PaddlePaddle to ONNX format." 
<< std::endl; return false; } + std::string onnx_model_proto(model_content_ptr, model_content_ptr + model_content_size); delete[] model_content_ptr; @@ -219,7 +229,7 @@ bool OrtBackend::InitFromOnnx(const std::string& model_file, return false; } std::string onnx_model_buffer; - if (option.enable_fp16) { + if (!converted_to_fp16 && option.enable_fp16) { if (option.device == Device::CPU) { FDWARNING << "Turning on FP16 on CPU may result in slower inference." << std::endl; diff --git a/fastdeploy/runtime/backends/ort/ort_backend.h b/fastdeploy/runtime/backends/ort/ort_backend.h index 4b80d0626..92627a292 100755 --- a/fastdeploy/runtime/backends/ort/ort_backend.h +++ b/fastdeploy/runtime/backends/ort/ort_backend.h @@ -78,6 +78,9 @@ class OrtBackend : public BaseBackend { // the ONNX model file name, // when ONNX is bigger than 2G, we will set this name std::string model_file_name; + // recored if the model has been converted to fp16 + bool converted_to_fp16 = false; + #ifndef NON_64_PLATFORM Ort::CustomOpDomain custom_op_domain_ = Ort::CustomOpDomain("Paddle"); #endif diff --git a/serving/src/fastdeploy_runtime.cc b/serving/src/fastdeploy_runtime.cc index 0372a5b5a..fbf08fb68 100644 --- a/serving/src/fastdeploy_runtime.cc +++ b/serving/src/fastdeploy_runtime.cc @@ -388,6 +388,8 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model) } else if (value_string == "pd_fp16") { // TODO(liqi): paddle inference don't currently have interface // for fp16. 
+      } else if (value_string == "ort_fp16") { +        runtime_options_->ort_option.enable_fp16 = true; } // } else if( param_key == "max_batch_size") { //   THROW_IF_BACKEND_MODEL_ERROR(ParseUnsignedLongLongValue( @@ -412,14 +414,15 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model) if (use_paddle) { // runtime_options_->EnablePaddleToTrt(); runtime_options_->UsePaddleInferBackend(); - runtime_options_->paddle_infer_option.enable_trt = true; + runtime_options_->paddle_infer_option.enable_trt = true; } } else if (param_key == "use_paddle_trt") { // Use new option setting policy to set paddle_trt backend bool use_paddle_trt; THROW_IF_BACKEND_MODEL_ERROR( ParseBoolValue(value_string, &use_paddle_trt)); - runtime_options_->paddle_infer_option.enable_trt = use_paddle_trt; + runtime_options_->paddle_infer_option.enable_trt = + use_paddle_trt; } else if (param_key == "use_paddle_log") { bool use_paddle_log; THROW_IF_BACKEND_MODEL_ERROR( @@ -436,6 +439,12 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model) SplitStringByDelimiter(value_string, ' ', &disable_trt_ops); runtime_options_->paddle_infer_option.DisableTrtOps( disable_trt_ops); + } else if (param_key == "disable_ort_fp16_op_types") { + std::vector<std::string> disable_ort_fp16_op_types; + SplitStringByDelimiter(value_string, ' ', + &disable_ort_fp16_op_types); + runtime_options_->ort_option.DisableOrtFP16OpTypes( + disable_ort_fp16_op_types); + } else if (param_key == "delete_passes") { std::vector<std::string> delete_passes; SplitStringByDelimiter(value_string, ' ', &delete_passes);