mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-09 10:30:37 +08:00
Add ORT fp16 support in server (#2069)
* add ort fp16 support in server * update paddle2onnx url * update ort fp16 api * add disable_ort_fp16_op_types in serving
This commit is contained in:
@@ -27,7 +27,7 @@ set(PADDLE2ONNX_LIB_DIR
|
|||||||
set(CMAKE_BUILD_RPATH "${CMAKE_BUILD_RPATH}"
|
set(CMAKE_BUILD_RPATH "${CMAKE_BUILD_RPATH}"
|
||||||
"${PADDLE2ONNX_LIB_DIR}")
|
"${PADDLE2ONNX_LIB_DIR}")
|
||||||
|
|
||||||
include_directories(${PADDLE2ONNX_INC_DIR})
|
include_directories(BEFORE ${PADDLE2ONNX_INC_DIR})
|
||||||
if(WIN32)
|
if(WIN32)
|
||||||
set(PADDLE2ONNX_COMPILE_LIB
|
set(PADDLE2ONNX_COMPILE_LIB
|
||||||
"${PADDLE2ONNX_INSTALL_DIR}/lib/paddle2onnx.lib"
|
"${PADDLE2ONNX_INSTALL_DIR}/lib/paddle2onnx.lib"
|
||||||
@@ -45,7 +45,7 @@ endif(WIN32)
|
|||||||
if (NOT PADDLE2ONNX_URL)
|
if (NOT PADDLE2ONNX_URL)
|
||||||
# Use default paddle2onnx url if custom url is not setting
|
# Use default paddle2onnx url if custom url is not setting
|
||||||
set(PADDLE2ONNX_URL_BASE "https://bj.bcebos.com/fastdeploy/third_libs/")
|
set(PADDLE2ONNX_URL_BASE "https://bj.bcebos.com/fastdeploy/third_libs/")
|
||||||
set(PADDLE2ONNX_VERSION "1.0.7")
|
set(PADDLE2ONNX_VERSION "1.0.8rc")
|
||||||
if(WIN32)
|
if(WIN32)
|
||||||
set(PADDLE2ONNX_FILE "paddle2onnx-win-x64-${PADDLE2ONNX_VERSION}.zip")
|
set(PADDLE2ONNX_FILE "paddle2onnx-win-x64-${PADDLE2ONNX_VERSION}.zip")
|
||||||
if(NOT CMAKE_CL_64)
|
if(NOT CMAKE_CL_64)
|
||||||
|
@@ -48,5 +48,10 @@ struct OrtBackendOption {
|
|||||||
void* external_stream_ = nullptr;
|
void* external_stream_ = nullptr;
|
||||||
/// Use fp16 to infer
|
/// Use fp16 to infer
|
||||||
bool enable_fp16 = false;
|
bool enable_fp16 = false;
|
||||||
|
|
||||||
|
std::vector<std::string> ort_disabled_ops_{};
|
||||||
|
void DisableOrtFP16OpTypes(const std::vector<std::string>& ops) {
|
||||||
|
ort_disabled_ops_.insert(ort_disabled_ops_.end(), ops.begin(), ops.end());
|
||||||
|
}
|
||||||
};
|
};
|
||||||
} // namespace fastdeploy
|
} // namespace fastdeploy
|
||||||
|
@@ -12,8 +12,8 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
#include "fastdeploy/runtime/backends/ort/option.h"
|
|
||||||
#include "fastdeploy/pybind/main.h"
|
#include "fastdeploy/pybind/main.h"
|
||||||
|
#include "fastdeploy/runtime/backends/ort/option.h"
|
||||||
|
|
||||||
namespace fastdeploy {
|
namespace fastdeploy {
|
||||||
|
|
||||||
@@ -29,7 +29,9 @@ void BindOrtOption(pybind11::module& m) {
|
|||||||
.def_readwrite("execution_mode", &OrtBackendOption::execution_mode)
|
.def_readwrite("execution_mode", &OrtBackendOption::execution_mode)
|
||||||
.def_readwrite("device", &OrtBackendOption::device)
|
.def_readwrite("device", &OrtBackendOption::device)
|
||||||
.def_readwrite("device_id", &OrtBackendOption::device_id)
|
.def_readwrite("device_id", &OrtBackendOption::device_id)
|
||||||
.def_readwrite("enable_fp16", &OrtBackendOption::enable_fp16);
|
.def_readwrite("enable_fp16", &OrtBackendOption::enable_fp16)
|
||||||
|
.def("disable_ort_fp16_op_types",
|
||||||
|
&OrtBackendOption::DisableOrtFP16OpTypes);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace fastdeploy
|
} // namespace fastdeploy
|
||||||
|
@@ -180,16 +180,26 @@ bool OrtBackend::InitFromPaddle(const std::string& model_buffer,
|
|||||||
strcpy(ops[0].export_op_name, "MultiClassNMS");
|
strcpy(ops[0].export_op_name, "MultiClassNMS");
|
||||||
strcpy(ops[1].op_name, "pool2d");
|
strcpy(ops[1].op_name, "pool2d");
|
||||||
strcpy(ops[1].export_op_name, "AdaptivePool2d");
|
strcpy(ops[1].export_op_name, "AdaptivePool2d");
|
||||||
|
converted_to_fp16 = option.enable_fp16;
|
||||||
|
|
||||||
|
std::vector<char*> disable_fp16_ops;
|
||||||
|
for (auto i = 0; i < option.ort_disabled_ops_.size(); i++) {
|
||||||
|
auto one_type = option.ort_disabled_ops_[i];
|
||||||
|
char* charStr = new char[one_type.size() + 1];
|
||||||
|
std::strcpy(charStr, one_type.c_str());
|
||||||
|
disable_fp16_ops.push_back(charStr);
|
||||||
|
}
|
||||||
if (!paddle2onnx::Export(
|
if (!paddle2onnx::Export(
|
||||||
model_buffer.c_str(), model_buffer.size(), params_buffer.c_str(),
|
model_buffer.c_str(), model_buffer.size(), params_buffer.c_str(),
|
||||||
params_buffer.size(), &model_content_ptr, &model_content_size, 11,
|
params_buffer.size(), &model_content_ptr, &model_content_size, 11,
|
||||||
true, verbose, true, true, true, ops.data(), 2, "onnxruntime",
|
true, verbose, true, true, true, ops.data(), 2, "onnxruntime",
|
||||||
nullptr, 0, "", &save_external, false)) {
|
nullptr, 0, "", &save_external, option.enable_fp16,
|
||||||
|
disable_fp16_ops.data(), option.ort_disabled_ops_.size())) {
|
||||||
FDERROR << "Error occured while export PaddlePaddle to ONNX format."
|
FDERROR << "Error occured while export PaddlePaddle to ONNX format."
|
||||||
<< std::endl;
|
<< std::endl;
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string onnx_model_proto(model_content_ptr,
|
std::string onnx_model_proto(model_content_ptr,
|
||||||
model_content_ptr + model_content_size);
|
model_content_ptr + model_content_size);
|
||||||
delete[] model_content_ptr;
|
delete[] model_content_ptr;
|
||||||
@@ -219,7 +229,7 @@ bool OrtBackend::InitFromOnnx(const std::string& model_file,
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
std::string onnx_model_buffer;
|
std::string onnx_model_buffer;
|
||||||
if (option.enable_fp16) {
|
if (!converted_to_fp16 && option.enable_fp16) {
|
||||||
if (option.device == Device::CPU) {
|
if (option.device == Device::CPU) {
|
||||||
FDWARNING << "Turning on FP16 on CPU may result in slower inference."
|
FDWARNING << "Turning on FP16 on CPU may result in slower inference."
|
||||||
<< std::endl;
|
<< std::endl;
|
||||||
|
@@ -78,6 +78,9 @@ class OrtBackend : public BaseBackend {
|
|||||||
// the ONNX model file name,
|
// the ONNX model file name,
|
||||||
// when ONNX is bigger than 2G, we will set this name
|
// when ONNX is bigger than 2G, we will set this name
|
||||||
std::string model_file_name;
|
std::string model_file_name;
|
||||||
|
// recored if the model has been converted to fp16
|
||||||
|
bool converted_to_fp16 = false;
|
||||||
|
|
||||||
#ifndef NON_64_PLATFORM
|
#ifndef NON_64_PLATFORM
|
||||||
Ort::CustomOpDomain custom_op_domain_ = Ort::CustomOpDomain("Paddle");
|
Ort::CustomOpDomain custom_op_domain_ = Ort::CustomOpDomain("Paddle");
|
||||||
#endif
|
#endif
|
||||||
|
@@ -388,6 +388,8 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
|
|||||||
} else if (value_string == "pd_fp16") {
|
} else if (value_string == "pd_fp16") {
|
||||||
// TODO(liqi): paddle inference don't currently have interface
|
// TODO(liqi): paddle inference don't currently have interface
|
||||||
// for fp16.
|
// for fp16.
|
||||||
|
} else if (value_string == "ort_fp16") {
|
||||||
|
runtime_options_->ort_option.enable_fp16 = true;
|
||||||
}
|
}
|
||||||
// } else if( param_key == "max_batch_size") {
|
// } else if( param_key == "max_batch_size") {
|
||||||
// THROW_IF_BACKEND_MODEL_ERROR(ParseUnsignedLongLongValue(
|
// THROW_IF_BACKEND_MODEL_ERROR(ParseUnsignedLongLongValue(
|
||||||
@@ -419,7 +421,8 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
|
|||||||
bool use_paddle_trt;
|
bool use_paddle_trt;
|
||||||
THROW_IF_BACKEND_MODEL_ERROR(
|
THROW_IF_BACKEND_MODEL_ERROR(
|
||||||
ParseBoolValue(value_string, &use_paddle_trt));
|
ParseBoolValue(value_string, &use_paddle_trt));
|
||||||
runtime_options_->paddle_infer_option.enable_trt = use_paddle_trt;
|
runtime_options_->paddle_infer_option.enable_trt =
|
||||||
|
use_paddle_trt;
|
||||||
} else if (param_key == "use_paddle_log") {
|
} else if (param_key == "use_paddle_log") {
|
||||||
bool use_paddle_log;
|
bool use_paddle_log;
|
||||||
THROW_IF_BACKEND_MODEL_ERROR(
|
THROW_IF_BACKEND_MODEL_ERROR(
|
||||||
@@ -436,6 +439,12 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
|
|||||||
SplitStringByDelimiter(value_string, ' ', &disable_trt_ops);
|
SplitStringByDelimiter(value_string, ' ', &disable_trt_ops);
|
||||||
runtime_options_->paddle_infer_option.DisableTrtOps(
|
runtime_options_->paddle_infer_option.DisableTrtOps(
|
||||||
disable_trt_ops);
|
disable_trt_ops);
|
||||||
|
} else if (param_key == "disable_ort_fp16_op_types") {
|
||||||
|
std::vector<std::string> disable_ort_fp16_op_types;
|
||||||
|
SplitStringByDelimiter(value_string, ' ',
|
||||||
|
&disable_ort_fp16_op_types);
|
||||||
|
runtime_options_->ort_option.DisableOrtFP16OpTypes(
|
||||||
|
disable_ort_fp16_op_types);
|
||||||
} else if (param_key == "delete_passes") {
|
} else if (param_key == "delete_passes") {
|
||||||
std::vector<std::string> delete_passes;
|
std::vector<std::string> delete_passes;
|
||||||
SplitStringByDelimiter(value_string, ' ', &delete_passes);
|
SplitStringByDelimiter(value_string, ' ', &delete_passes);
|
||||||
|
Reference in New Issue
Block a user