[Backend] Add pybind & PaddleDetection example for TVM (#1998)

* update

* update

* Update infer_ppyoloe_demo.cc

---------

Co-authored-by: DefTruth <31974251+DefTruth@users.noreply.github.com>
Author: Zheng-Bicheng
Date: 2023-06-04 13:26:47 +08:00
Committed by: GitHub
parent c634a9260d
commit 8d357814e8
10 changed files with 189 additions and 24 deletions

View File

@@ -30,7 +30,7 @@ set(DLPACK_PATH "${THIRD_PARTY_PATH}/install/dlpack")
 execute_process(COMMAND ${CMAKE_COMMAND} -E make_directory "${DLPACK_PATH}")
 execute_process(COMMAND ${CMAKE_COMMAND} -E copy_directory
                 "${PROJECT_SOURCE_DIR}/third_party/dlpack"
-                "${THIRD_PARTY_PATH}/install/")
+                "${THIRD_PARTY_PATH}/install/dlpack")
 include_directories(${DLPACK_PATH}/include)
 set(DMLC_CORE_PATH "${THIRD_PARTY_PATH}/install/dmlc-core")

View File

@@ -4,7 +4,7 @@
 This directory provides `infer_ppyoloe_demo.cc` as an example of quickly deploying PaddleDetection models with TVM acceleration.
-## Convert the Model and Run
+## Run
 ```bash
 # build example

View File

@@ -15,8 +15,8 @@
#include "fastdeploy/vision.h" #include "fastdeploy/vision.h"
void TVMInfer(const std::string& model_dir, const std::string& image_file) { void TVMInfer(const std::string& model_dir, const std::string& image_file) {
auto model_file = model_dir + "/tvm_model"; auto model_file = model_dir + "/tvm_model.so";
auto params_file = ""; auto params_file = model_dir + "/tvm_model.params";
auto config_file = model_dir + "/infer_cfg.yml"; auto config_file = model_dir + "/infer_cfg.yml";
auto option = fastdeploy::RuntimeOption(); auto option = fastdeploy::RuntimeOption();
@@ -54,4 +54,4 @@ int main(int argc, char* argv[]) {
TVMInfer(argv[1], argv[2]); TVMInfer(argv[1], argv[2]);
return 0; return 0;
} }

View File

@@ -0,0 +1,80 @@
[English](README.md) | 简体中文
# PaddleDetection Python Deployment Example
This directory provides `infer_ppyoloe.py` as an example of quickly deploying PaddleDetection models with TVM acceleration.
## Run
```bash
# copy model to example folder
cp -r /path/to/model ./
wget https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg
python infer_ppyoloe.py --model_dir tvm_save --image 000000014439.jpg --device cpu
```
After running, the visualized result is shown below
<div align="center">
<img src="https://user-images.githubusercontent.com/19339784/184326520-7075e907-10ed-4fad-93f8-52d0e35d4964.jpg", width=480px, height=320px />
</div>
## PaddleDetection Python Interface
```python
fastdeploy.vision.detection.PPYOLOE(model_file, params_file, config_file, runtime_option=None, model_format=ModelFormat.PADDLE)
fastdeploy.vision.detection.PicoDet(model_file, params_file, config_file, runtime_option=None, model_format=ModelFormat.PADDLE)
fastdeploy.vision.detection.PaddleYOLOX(model_file, params_file, config_file, runtime_option=None, model_format=ModelFormat.PADDLE)
fastdeploy.vision.detection.YOLOv3(model_file, params_file, config_file, runtime_option=None, model_format=ModelFormat.PADDLE)
fastdeploy.vision.detection.PPYOLO(model_file, params_file, config_file, runtime_option=None, model_format=ModelFormat.PADDLE)
fastdeploy.vision.detection.FasterRCNN(model_file, params_file, config_file, runtime_option=None, model_format=ModelFormat.PADDLE)
fastdeploy.vision.detection.MaskRCNN(model_file, params_file, config_file, runtime_option=None, model_format=ModelFormat.PADDLE)
fastdeploy.vision.detection.SSD(model_file, params_file, config_file, runtime_option=None, model_format=ModelFormat.PADDLE)
fastdeploy.vision.detection.PaddleYOLOv5(model_file, params_file, config_file, runtime_option=None, model_format=ModelFormat.PADDLE)
fastdeploy.vision.detection.PaddleYOLOv6(model_file, params_file, config_file, runtime_option=None, model_format=ModelFormat.PADDLE)
fastdeploy.vision.detection.PaddleYOLOv7(model_file, params_file, config_file, runtime_option=None, model_format=ModelFormat.PADDLE)
fastdeploy.vision.detection.RTMDet(model_file, params_file, config_file, runtime_option=None, model_format=ModelFormat.PADDLE)
fastdeploy.vision.detection.CascadeRCNN(model_file, params_file, config_file, runtime_option=None, model_format=ModelFormat.PADDLE)
fastdeploy.vision.detection.PSSDet(model_file, params_file, config_file, runtime_option=None, model_format=ModelFormat.PADDLE)
fastdeploy.vision.detection.RetinaNet(model_file, params_file, config_file, runtime_option=None, model_format=ModelFormat.PADDLE)
fastdeploy.vision.detection.PPYOLOESOD(model_file, params_file, config_file, runtime_option=None, model_format=ModelFormat.PADDLE)
fastdeploy.vision.detection.FCOS(model_file, params_file, config_file, runtime_option=None, model_format=ModelFormat.PADDLE)
fastdeploy.vision.detection.TTFNet(model_file, params_file, config_file, runtime_option=None, model_format=ModelFormat.PADDLE)
fastdeploy.vision.detection.TOOD(model_file, params_file, config_file, runtime_option=None, model_format=ModelFormat.PADDLE)
fastdeploy.vision.detection.GFL(model_file, params_file, config_file, runtime_option=None, model_format=ModelFormat.PADDLE)
```
PaddleDetection model loading and initialization, where model_file and params_file are the exported Paddle deployment model files, and config_file is the deployment configuration YAML file exported by PaddleDetection alongside them.
**Parameters**
> * **model_file**(str): Path to the model file
> * **params_file**(str): Path to the parameters file
> * **config_file**(str): Path to the inference configuration YAML file
> * **runtime_option**(RuntimeOption): Backend inference configuration; defaults to None, which uses the default configuration
> * **model_format**(ModelFormat): Model format; defaults to Paddle
### predict Function
Each model in PaddleDetection (PPYOLOE/PicoDet/PaddleYOLOX/YOLOv3/PPYOLO/FasterRCNN, etc.) provides the following member function for image detection:
> ```python
> PPYOLOE.predict(image_data, conf_threshold=0.25, nms_iou_threshold=0.5)
> ```
>
> Model prediction interface: takes an image as input and directly returns the detection result.
>
> **Parameters**
>
> > * **image_data**(np.ndarray): Input data; note that it must be in HWC, BGR format
> **Return**
>
> > Returns a `fastdeploy.vision.DetectionResult` structure; see the documentation [Vision Model Prediction Results](../../../../../docs/api/vision_results/) for a description of the structure
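For reference, a minimal end-to-end sketch of the interface above on the TVM deployment path, assuming the artifacts used elsewhere in this example (`tvm_save/tvm_model.so`, `tvm_save/tvm_model.params`, `tvm_save/infer_cfg.yml`) and the test image downloaded in the Run section:
```python
import cv2
import fastdeploy as fd

# Select CPU inference with the TVM Runtime backend.
option = fd.RuntimeOption()
option.use_cpu()
option.use_tvm_backend()

# Load PPYOLOE exported for TVM; NMS is applied in postprocessing,
# as in infer_ppyoloe.py.
model = fd.vision.detection.PPYOLOE(
    "tvm_save/tvm_model.so",
    "tvm_save/tvm_model.params",
    "tvm_save/infer_cfg.yml",
    runtime_option=option,
    model_format=fd.ModelFormat.TVMFormat)
model.postprocessor.apply_nms()

im = cv2.imread("000000014439.jpg")  # HWC, BGR, as required by predict
result = model.predict(im)           # fastdeploy.vision.DetectionResult
print(result)
```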
## Other Documents
- [PaddleDetection Model Overview](../..)
- [PaddleDetection C++ Deployment](../cpp)
- [Description of Model Prediction Results](../../../../../../docs/api/vision_results/)
- [How to Switch the Model Inference Backend](../../../../../../docs/cn/faq/how_to_change_backend.md)

View File

@@ -0,0 +1,68 @@
import cv2
import os

import fastdeploy as fd


def parse_arguments():
    import argparse
    import ast
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_dir",
        default=None,
        help="Path of PaddleDetection model directory")
    parser.add_argument(
        "--image", default=None, help="Path of test image file.")
    parser.add_argument(
        "--device",
        type=str,
        default='cpu',
        help="Type of inference device, support 'kunlunxin', 'cpu' or 'gpu'.")
    parser.add_argument(
        "--use_trt",
        type=ast.literal_eval,
        default=False,
        help="Whether to use TensorRT.")
    return parser.parse_args()


def build_option(args):
    # This example always runs on CPU with the TVM backend.
    option = fd.RuntimeOption()
    option.use_cpu()
    option.use_tvm_backend()
    return option


args = parse_arguments()

if args.model_dir is None:
    model_dir = fd.download_model(name='ppyoloe_crn_l_300e_coco')
else:
    model_dir = args.model_dir

model_file = os.path.join(model_dir, "tvm_model.so")
params_file = os.path.join(model_dir, "tvm_model.params")
config_file = os.path.join(model_dir, "infer_cfg.yml")

# Configure the runtime and load the model
runtime_option = build_option(args)
model = fd.vision.detection.PPYOLOE(
    model_file,
    params_file,
    config_file,
    runtime_option=runtime_option,
    model_format=fd.ModelFormat.TVMFormat)
model.postprocessor.apply_nms()

# Predict detection results for the image
if args.image is None:
    image = fd.utils.get_detection_test_image()
else:
    image = args.image
im = cv2.imread(image)
result = model.predict(im)
print(result)

# Visualize the prediction results
vis_im = fd.vision.vis_detection(im, result, score_threshold=0.5)
cv2.imwrite("visualized_result.jpg", vis_im)
print("Visualized result saved in ./visualized_result.jpg")

View File

@@ -136,6 +136,7 @@ void BindRuntime(pybind11::module& m) {
.value("PDINFER", Backend::PDINFER) .value("PDINFER", Backend::PDINFER)
.value("RKNPU2", Backend::RKNPU2) .value("RKNPU2", Backend::RKNPU2)
.value("SOPHGOTPU", Backend::SOPHGOTPU) .value("SOPHGOTPU", Backend::SOPHGOTPU)
.value("TVM", Backend::TVM)
.value("LITE", Backend::LITE); .value("LITE", Backend::LITE);
pybind11::enum_<ModelFormat>(m, "ModelFormat", pybind11::arithmetic(), pybind11::enum_<ModelFormat>(m, "ModelFormat", pybind11::arithmetic(),
"ModelFormat for inference.") "ModelFormat for inference.")
@@ -143,7 +144,8 @@ void BindRuntime(pybind11::module& m) {
.value("TORCHSCRIPT", ModelFormat::TORCHSCRIPT) .value("TORCHSCRIPT", ModelFormat::TORCHSCRIPT)
.value("RKNN", ModelFormat::RKNN) .value("RKNN", ModelFormat::RKNN)
.value("SOPHGO", ModelFormat::SOPHGO) .value("SOPHGO", ModelFormat::SOPHGO)
.value("ONNX", ModelFormat::ONNX); .value("ONNX", ModelFormat::ONNX)
.value("TVMFormat", ModelFormat::TVMFormat);
pybind11::enum_<Device>(m, "Device", pybind11::arithmetic(), pybind11::enum_<Device>(m, "Device", pybind11::arithmetic(),
"Device for inference.") "Device for inference.")
.value("CPU", Device::CPU) .value("CPU", Device::CPU)

View File

@@ -59,14 +59,13 @@ bool TVMBackend::InitInputAndOutputTensor() {
 bool TVMBackend::BuildModel(const RuntimeOption& runtime_option) {
   // load in the library
   tvm::runtime::Module mod_factory =
-      tvm::runtime::Module::LoadFromFile(runtime_option.model_file + ".so");
+      tvm::runtime::Module::LoadFromFile(runtime_option.model_file);
   // create the graph executor module
   gmod_ = mod_factory.GetFunction("default")(dev_);
   // load params
-  std::ifstream params_in(runtime_option.model_file + ".params",
-                          std::ios::binary);
+  std::ifstream params_in(runtime_option.params_file, std::ios::binary);
   std::string params_data((std::istreambuf_iterator<char>(params_in)),
                           std::istreambuf_iterator<char>());
   params_in.close();

fastdeploy/runtime/option_pybind.cc — 7 changes (Executable file → Normal file)
View File

@@ -43,8 +43,10 @@ void BindOption(pybind11::module& m) {
.def("use_sophgo", &RuntimeOption::UseSophgo) .def("use_sophgo", &RuntimeOption::UseSophgo)
.def("use_ascend", &RuntimeOption::UseAscend) .def("use_ascend", &RuntimeOption::UseAscend)
.def("use_kunlunxin", &RuntimeOption::UseKunlunXin) .def("use_kunlunxin", &RuntimeOption::UseKunlunXin)
.def("disable_valid_backend_check",&RuntimeOption::DisableValidBackendCheck) .def("disable_valid_backend_check",
.def("enable_valid_backend_check",&RuntimeOption::EnableValidBackendCheck) &RuntimeOption::DisableValidBackendCheck)
.def("enable_valid_backend_check",
&RuntimeOption::EnableValidBackendCheck)
.def_readwrite("paddle_lite_option", &RuntimeOption::paddle_lite_option) .def_readwrite("paddle_lite_option", &RuntimeOption::paddle_lite_option)
.def_readwrite("openvino_option", &RuntimeOption::openvino_option) .def_readwrite("openvino_option", &RuntimeOption::openvino_option)
.def_readwrite("ort_option", &RuntimeOption::ort_option) .def_readwrite("ort_option", &RuntimeOption::ort_option)
@@ -59,6 +61,7 @@ void BindOption(pybind11::module& m) {
.def("set_cpu_thread_num", &RuntimeOption::SetCpuThreadNum) .def("set_cpu_thread_num", &RuntimeOption::SetCpuThreadNum)
.def("use_paddle_backend", &RuntimeOption::UsePaddleBackend) .def("use_paddle_backend", &RuntimeOption::UsePaddleBackend)
.def("use_poros_backend", &RuntimeOption::UsePorosBackend) .def("use_poros_backend", &RuntimeOption::UsePorosBackend)
.def("use_tvm_backend", &RuntimeOption::UseTVMBackend)
.def("use_ort_backend", &RuntimeOption::UseOrtBackend) .def("use_ort_backend", &RuntimeOption::UseOrtBackend)
.def("use_trt_backend", &RuntimeOption::UseTrtBackend) .def("use_trt_backend", &RuntimeOption::UseTrtBackend)
.def("use_openvino_backend", &RuntimeOption::UseOpenVINOBackend) .def("use_openvino_backend", &RuntimeOption::UseOpenVINOBackend)

View File

@@ -244,10 +244,9 @@ class RuntimeOption:
         :param enable_multi_stream: (bool)Whether to enable the multi stream of KunlunXin XPU.
         :param gm_default_size The default size of context global memory of KunlunXin XPU.
         """
-        return self._option.use_kunlunxin(device_id, l3_workspace_size, locked,
-                                          autotune, autotune_file, precision,
-                                          adaptive_seqlen, enable_multi_stream,
-                                          gm_default_size)
+        return self._option.use_kunlunxin(
+            device_id, l3_workspace_size, locked, autotune, autotune_file,
+            precision, adaptive_seqlen, enable_multi_stream, gm_default_size)

     def use_cpu(self):
         """Inference with CPU
@@ -271,7 +270,7 @@ class RuntimeOption:
     def disable_valid_backend_check(self):
         """ Disable checking validity of backend during inference
         """
         return self._option.disable_valid_backend_check()

     def enable_valid_backend_check(self):
@@ -316,6 +315,11 @@ class RuntimeOption:
""" """
return self._option.use_ort_backend() return self._option.use_ort_backend()
def use_tvm_backend(self):
"""Use TVM Runtime backend, support inference TVM model on CPU.
"""
return self._option.use_tvm_backend()
def use_trt_backend(self): def use_trt_backend(self):
"""Use TensorRT backend, support inference Paddle/ONNX model on Nvidia GPU. """Use TensorRT backend, support inference Paddle/ONNX model on Nvidia GPU.
""" """

View File

@@ -57,12 +57,18 @@ setup_configs = dict()
setup_configs["LIBRARY_NAME"] = PACKAGE_NAME setup_configs["LIBRARY_NAME"] = PACKAGE_NAME
setup_configs["PY_LIBRARY_NAME"] = PACKAGE_NAME + "_main" setup_configs["PY_LIBRARY_NAME"] = PACKAGE_NAME + "_main"
# Backend options # Backend options
setup_configs["ENABLE_RKNPU2_BACKEND"] = os.getenv("ENABLE_RKNPU2_BACKEND", "OFF") setup_configs["ENABLE_TVM_BACKEND"] = os.getenv("ENABLE_TVM_BACKEND", "OFF")
setup_configs["ENABLE_SOPHGO_BACKEND"] = os.getenv("ENABLE_SOPHGO_BACKEND", "OFF") setup_configs["ENABLE_RKNPU2_BACKEND"] = os.getenv("ENABLE_RKNPU2_BACKEND",
"OFF")
setup_configs["ENABLE_SOPHGO_BACKEND"] = os.getenv("ENABLE_SOPHGO_BACKEND",
"OFF")
setup_configs["ENABLE_ORT_BACKEND"] = os.getenv("ENABLE_ORT_BACKEND", "OFF") setup_configs["ENABLE_ORT_BACKEND"] = os.getenv("ENABLE_ORT_BACKEND", "OFF")
setup_configs["ENABLE_OPENVINO_BACKEND"] = os.getenv("ENABLE_OPENVINO_BACKEND", "OFF") setup_configs["ENABLE_OPENVINO_BACKEND"] = os.getenv("ENABLE_OPENVINO_BACKEND",
setup_configs["ENABLE_PADDLE_BACKEND"] = os.getenv("ENABLE_PADDLE_BACKEND", "OFF") "OFF")
setup_configs["ENABLE_POROS_BACKEND"] = os.getenv("ENABLE_POROS_BACKEND", "OFF") setup_configs["ENABLE_PADDLE_BACKEND"] = os.getenv("ENABLE_PADDLE_BACKEND",
"OFF")
setup_configs["ENABLE_POROS_BACKEND"] = os.getenv("ENABLE_POROS_BACKEND",
"OFF")
setup_configs["ENABLE_TRT_BACKEND"] = os.getenv("ENABLE_TRT_BACKEND", "OFF") setup_configs["ENABLE_TRT_BACKEND"] = os.getenv("ENABLE_TRT_BACKEND", "OFF")
setup_configs["ENABLE_LITE_BACKEND"] = os.getenv("ENABLE_LITE_BACKEND", "OFF") setup_configs["ENABLE_LITE_BACKEND"] = os.getenv("ENABLE_LITE_BACKEND", "OFF")
setup_configs["ENABLE_VISION"] = os.getenv("ENABLE_VISION", "OFF") setup_configs["ENABLE_VISION"] = os.getenv("ENABLE_VISION", "OFF")
@@ -82,11 +88,14 @@ setup_configs["WITH_KUNLUNXIN"] = os.getenv("WITH_KUNLUNXIN", "OFF")
setup_configs["RKNN2_TARGET_SOC"] = os.getenv("RKNN2_TARGET_SOC", "") setup_configs["RKNN2_TARGET_SOC"] = os.getenv("RKNN2_TARGET_SOC", "")
# Custom deps settings # Custom deps settings
setup_configs["TRT_DIRECTORY"] = os.getenv("TRT_DIRECTORY", "UNDEFINED") setup_configs["TRT_DIRECTORY"] = os.getenv("TRT_DIRECTORY", "UNDEFINED")
setup_configs["CUDA_DIRECTORY"] = os.getenv("CUDA_DIRECTORY", "/usr/local/cuda") setup_configs["CUDA_DIRECTORY"] = os.getenv("CUDA_DIRECTORY",
"/usr/local/cuda")
setup_configs["OPENCV_DIRECTORY"] = os.getenv("OPENCV_DIRECTORY", "") setup_configs["OPENCV_DIRECTORY"] = os.getenv("OPENCV_DIRECTORY", "")
setup_configs["ORT_DIRECTORY"] = os.getenv("ORT_DIRECTORY", "") setup_configs["ORT_DIRECTORY"] = os.getenv("ORT_DIRECTORY", "")
setup_configs["PADDLEINFERENCE_DIRECTORY"] = os.getenv("PADDLEINFERENCE_DIRECTORY", "") setup_configs["PADDLEINFERENCE_DIRECTORY"] = os.getenv(
setup_configs["PADDLEINFERENCE_VERSION"] = os.getenv("PADDLEINFERENCE_VERSION", "") "PADDLEINFERENCE_DIRECTORY", "")
setup_configs["PADDLEINFERENCE_VERSION"] = os.getenv("PADDLEINFERENCE_VERSION",
"")
setup_configs["PADDLEINFERENCE_URL"] = os.getenv("PADDLEINFERENCE_URL", "") setup_configs["PADDLEINFERENCE_URL"] = os.getenv("PADDLEINFERENCE_URL", "")
setup_configs["PADDLE2ONNX_URL"] = os.getenv("PADDLE2ONNX_URL", "") setup_configs["PADDLE2ONNX_URL"] = os.getenv("PADDLE2ONNX_URL", "")
setup_configs["PADDLELITE_URL"] = os.getenv("PADDLELITE_URL", "") setup_configs["PADDLELITE_URL"] = os.getenv("PADDLELITE_URL", "")