diff --git a/csrc/fastdeploy/backends/paddle/paddle_backend.cc b/csrc/fastdeploy/backends/paddle/paddle_backend.cc index 2fae38937..f1d7605fc 100644 --- a/csrc/fastdeploy/backends/paddle/paddle_backend.cc +++ b/csrc/fastdeploy/backends/paddle/paddle_backend.cc @@ -26,6 +26,9 @@ void PaddleBackend::BuildOption(const PaddleBackendOption& option) { config_.SetMkldnnCacheCapacity(option.mkldnn_cache_size); } } + if (!option.enable_log_info) { + config_.DisableGlogInfo(); + } config_.SetCpuMathLibraryNumThreads(option.cpu_thread_num); } diff --git a/csrc/fastdeploy/backends/paddle/paddle_backend.h b/csrc/fastdeploy/backends/paddle/paddle_backend.h index 99ca5eb1b..22078ab14 100644 --- a/csrc/fastdeploy/backends/paddle/paddle_backend.h +++ b/csrc/fastdeploy/backends/paddle/paddle_backend.h @@ -32,6 +32,8 @@ struct PaddleBackendOption { #endif bool enable_mkldnn = true; + bool enable_log_info = false; + int mkldnn_cache_size = 1; int cpu_thread_num = 8; // initialize memory size(MB) for GPU diff --git a/csrc/fastdeploy/fastdeploy_model.cc b/csrc/fastdeploy/fastdeploy_model.cc index c4dbc70a7..31781ac3a 100644 --- a/csrc/fastdeploy/fastdeploy_model.cc +++ b/csrc/fastdeploy/fastdeploy_model.cc @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. #include "fastdeploy/fastdeploy_model.h" -#include "fastdeploy/utils/unique_ptr.h" #include "fastdeploy/utils/utils.h" namespace fastdeploy { @@ -54,12 +53,52 @@ bool FastDeployModel::InitRuntime() { << std::endl; return false; } - runtime_ = utils::make_unique<Runtime>(); - if (!runtime_->Init(runtime_option)) { - return false; + + bool use_gpu = (runtime_option.device == Device::GPU); +#ifndef WITH_GPU + use_gpu = false; +#endif + + // whether the model is supported by the specified backend + bool is_supported = false; + if (use_gpu) { + for (auto& item : valid_gpu_backends) { + if (item == runtime_option.backend) { + is_supported = true; + break; + } + } + } else { + for (auto& item : valid_cpu_backends) { + if (item == runtime_option.backend) { + is_supported = true; + break; + } + } + } + + if (is_supported) { + runtime_ = std::unique_ptr<Runtime>(new Runtime()); + if (!runtime_->Init(runtime_option)) { + return false; + } + runtime_initialized_ = true; + return true; + } else { + FDWARNING << ModelName() << " is not supported with backend " + << Str(runtime_option.backend) << "." << std::endl; + if (use_gpu) { + FDASSERT(valid_gpu_backends.size() > 0, + "There's no valid gpu backend for " + ModelName() + "."); + FDWARNING << "FastDeploy will choose " << Str(valid_gpu_backends[0]) + << " for model inference." << std::endl; + } else { + FDASSERT(valid_cpu_backends.size() > 0, + "There's no valid cpu backend for " + ModelName() + "."); + FDWARNING << "FastDeploy will choose " << Str(valid_cpu_backends[0]) + << " for model inference."
<< std::endl; + } } - runtime_initialized_ = true; - return true; } if (runtime_option.device == Device::CPU) { diff --git a/csrc/fastdeploy/fastdeploy_runtime.cc b/csrc/fastdeploy/fastdeploy_runtime.cc index e5c41a29a..c2a16b903 100644 --- a/csrc/fastdeploy/fastdeploy_runtime.cc +++ b/csrc/fastdeploy/fastdeploy_runtime.cc @@ -181,6 +181,10 @@ void RuntimeOption::EnablePaddleMKLDNN() { pd_enable_mkldnn = true; } void RuntimeOption::DisablePaddleMKLDNN() { pd_enable_mkldnn = false; } +void RuntimeOption::EnablePaddleLogInfo() { pd_enable_log_info = true; } + +void RuntimeOption::DisablePaddleLogInfo() { pd_enable_log_info = false; } + void RuntimeOption::SetPaddleMKLDNNCacheSize(int size) { FDASSERT(size > 0, "Parameter size must greater than 0."); pd_mkldnn_cache_size = size; @@ -272,6 +276,7 @@ void Runtime::CreatePaddleBackend() { #ifdef ENABLE_PADDLE_BACKEND auto pd_option = PaddleBackendOption(); pd_option.enable_mkldnn = option.pd_enable_mkldnn; + pd_option.enable_log_info = option.pd_enable_log_info; pd_option.mkldnn_cache_size = option.pd_mkldnn_cache_size; pd_option.use_gpu = (option.device == Device::GPU) ? true : false; pd_option.gpu_id = option.device_id; diff --git a/csrc/fastdeploy/fastdeploy_runtime.h b/csrc/fastdeploy/fastdeploy_runtime.h index 780945458..ab6b4a188 100644 --- a/csrc/fastdeploy/fastdeploy_runtime.h +++ b/csrc/fastdeploy/fastdeploy_runtime.h @@ -68,6 +68,11 @@ struct FASTDEPLOY_DECL RuntimeOption { // disable mkldnn while use paddle inference in CPU void DisablePaddleMKLDNN(); + // enable debug information of paddle backend + void EnablePaddleLogInfo(); + // disable debug information of paddle backend + void DisablePaddleLogInfo(); + // set size of cached shape while enable mkldnn with paddle inference backend void SetPaddleMKLDNNCacheSize(int size); @@ -108,6 +113,7 @@ struct FASTDEPLOY_DECL RuntimeOption { // ======Only for Paddle Backend===== bool pd_enable_mkldnn = true; + bool pd_enable_log_info = false; int pd_mkldnn_cache_size = 1; // ======Only for Trt Backend======= diff --git a/csrc/fastdeploy/pybind/fastdeploy_runtime.cc b/csrc/fastdeploy/pybind/fastdeploy_runtime.cc index 412b1ccef..86e5b69c7 100644 --- a/csrc/fastdeploy/pybind/fastdeploy_runtime.cc +++ b/csrc/fastdeploy/pybind/fastdeploy_runtime.cc @@ -28,6 +28,8 @@ void BindRuntime(pybind11::module& m) { .def("use_trt_backend", &RuntimeOption::UseTrtBackend) .def("enable_paddle_mkldnn", &RuntimeOption::EnablePaddleMKLDNN) .def("disable_paddle_mkldnn", &RuntimeOption::DisablePaddleMKLDNN) + .def("enable_paddle_log_info", &RuntimeOption::EnablePaddleLogInfo) + .def("disable_paddle_log_info", &RuntimeOption::DisablePaddleLogInfo) .def("set_paddle_mkldnn_cache_size", &RuntimeOption::SetPaddleMKLDNNCacheSize) .def("set_trt_input_shape", &RuntimeOption::SetTrtInputShape) diff --git a/csrc/fastdeploy/vision/detection/ppdet/picodet.cc b/csrc/fastdeploy/vision/detection/ppdet/picodet.cc index d89fab2ae..7c961d1f8 100644 --- a/csrc/fastdeploy/vision/detection/ppdet/picodet.cc +++ b/csrc/fastdeploy/vision/detection/ppdet/picodet.cc @@ -24,8 +24,8 @@ PicoDet::PicoDet(const std::string& model_file, const std::string& params_file, const RuntimeOption& custom_option, const Frontend& model_format) { config_file_ = config_file; - valid_cpu_backends = {Backend::PDINFER, Backend::ORT}; - valid_gpu_backends = {Backend::PDINFER, Backend::ORT}; + valid_cpu_backends = {Backend::ORT, Backend::PDINFER}; + valid_gpu_backends = {Backend::ORT, Backend::PDINFER, Backend::TRT}; runtime_option = custom_option; 
runtime_option.model_format = model_format; runtime_option.model_file = model_file; diff --git a/csrc/fastdeploy/vision/detection/ppdet/ppyoloe.cc b/csrc/fastdeploy/vision/detection/ppdet/ppyoloe.cc index 2e4b56ecb..12786a08a 100644 --- a/csrc/fastdeploy/vision/detection/ppdet/ppyoloe.cc +++ b/csrc/fastdeploy/vision/detection/ppdet/ppyoloe.cc @@ -14,8 +14,8 @@ PPYOLOE::PPYOLOE(const std::string& model_file, const std::string& params_file, const RuntimeOption& custom_option, const Frontend& model_format) { config_file_ = config_file; - valid_cpu_backends = {Backend::PDINFER, Backend::ORT}; - valid_gpu_backends = {Backend::PDINFER, Backend::ORT}; + valid_cpu_backends = {Backend::ORT, Backend::PDINFER}; + valid_gpu_backends = {Backend::ORT, Backend::PDINFER, Backend::TRT}; runtime_option = custom_option; runtime_option.model_format = model_format; runtime_option.model_file = model_file; diff --git a/csrc/fastdeploy/vision/detection/ppdet/yolov3.cc b/csrc/fastdeploy/vision/detection/ppdet/yolov3.cc index 309d65640..8de0ec231 100644 --- a/csrc/fastdeploy/vision/detection/ppdet/yolov3.cc +++ b/csrc/fastdeploy/vision/detection/ppdet/yolov3.cc @@ -23,8 +23,8 @@ YOLOv3::YOLOv3(const std::string& model_file, const std::string& params_file, const RuntimeOption& custom_option, const Frontend& model_format) { config_file_ = config_file; - valid_cpu_backends = {Backend::PDINFER}; - valid_gpu_backends = {Backend::PDINFER}; + valid_cpu_backends = {Backend::ORT, Backend::PDINFER}; + valid_gpu_backends = {Backend::ORT, Backend::PDINFER, Backend::TRT}; runtime_option = custom_option; runtime_option.model_format = model_format; runtime_option.model_file = model_file; diff --git a/csrc/fastdeploy/vision/detection/ppdet/yolox.cc b/csrc/fastdeploy/vision/detection/ppdet/yolox.cc index a60ebfcc4..dbf0824ba 100644 --- a/csrc/fastdeploy/vision/detection/ppdet/yolox.cc +++ b/csrc/fastdeploy/vision/detection/ppdet/yolox.cc @@ -18,12 +18,14 @@ namespace fastdeploy { namespace vision { namespace detection { -PaddleYOLOX::PaddleYOLOX(const std::string& model_file, const std::string& params_file, - const std::string& config_file, const RuntimeOption& custom_option, - const Frontend& model_format) { +PaddleYOLOX::PaddleYOLOX(const std::string& model_file, + const std::string& params_file, + const std::string& config_file, + const RuntimeOption& custom_option, + const Frontend& model_format) { config_file_ = config_file; - valid_cpu_backends = {Backend::PDINFER, Backend::ORT}; - valid_gpu_backends = {Backend::PDINFER, Backend::ORT}; + valid_cpu_backends = {Backend::ORT, Backend::PDINFER}; + valid_gpu_backends = {Backend::ORT, Backend::PDINFER, Backend::TRT}; runtime_option = custom_option; runtime_option.model_format = model_format; runtime_option.model_file = model_file; diff --git a/csrc/fastdeploy/vision/visualize/detection.cc b/csrc/fastdeploy/vision/visualize/detection.cc index 147ef6556..693e9da72 100644 --- a/csrc/fastdeploy/vision/visualize/detection.cc +++ b/csrc/fastdeploy/vision/visualize/detection.cc @@ -24,13 +24,17 @@ namespace vision { // If need to visualize num_classes > 1000 // Please call Visualize::GetColorMap(num_classes) first cv::Mat Visualize::VisDetection(const cv::Mat& im, - const DetectionResult& result, int line_size, + const DetectionResult& result, + float score_threshold, int line_size, float font_size) { auto color_map = GetColorMap(); int h = im.rows; int w = im.cols; auto vis_im = im.clone(); for (size_t i = 0; i < result.boxes.size(); ++i) { + if (result.scores[i] < score_threshold) { + 
continue; + } cv::Rect rect(result.boxes[i][0], result.boxes[i][1], result.boxes[i][2] - result.boxes[i][0], result.boxes[i][3] - result.boxes[i][1]); diff --git a/csrc/fastdeploy/vision/visualize/visualize.h b/csrc/fastdeploy/vision/visualize/visualize.h index bee62c301..e8709d730 100644 --- a/csrc/fastdeploy/vision/visualize/visualize.h +++ b/csrc/fastdeploy/vision/visualize/visualize.h @@ -26,7 +26,8 @@ class FASTDEPLOY_DECL Visualize { static std::vector<int> color_map_; static const std::vector<int>& GetColorMap(int num_classes = 1000); static cv::Mat VisDetection(const cv::Mat& im, const DetectionResult& result, - int line_size = 1, float font_size = 0.5f); + float score_threshold = 0.0, int line_size = 1, + float font_size = 0.5f); static cv::Mat VisFaceDetection(const cv::Mat& im, const FaceDetectionResult& result, int line_size = 1, float font_size = 0.5f); diff --git a/csrc/fastdeploy/vision/visualize/visualize_pybind.cc b/csrc/fastdeploy/vision/visualize/visualize_pybind.cc index 36010acf1..508ac84c6 100644 --- a/csrc/fastdeploy/vision/visualize/visualize_pybind.cc +++ b/csrc/fastdeploy/vision/visualize/visualize_pybind.cc @@ -20,10 +20,10 @@ void BindVisualize(pybind11::module& m) { .def(pybind11::init<>()) .def_static("vis_detection", [](pybind11::array& im_data, vision::DetectionResult& result, - int line_size, float font_size) { + float score_threshold, int line_size, float font_size) { auto im = PyArrayToCvMat(im_data); auto vis_im = vision::Visualize::VisDetection( - im, result, line_size, font_size); + im, result, score_threshold, line_size, font_size); FDTensor out; vision::Mat(vis_im).ShareWithTensor(&out); return TensorToPyArray(out); diff --git a/examples/vision/detection/paddledetection/README.md b/examples/vision/detection/paddledetection/README.md new file mode 100644 index 000000000..4ab0411c2 --- /dev/null +++ b/examples/vision/detection/paddledetection/README.md @@ -0,0 +1,45 @@ +# PaddleDetection模型部署 + +## 模型版本说明 + +- [PaddleDetection Release/2.4](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4) + +## 支持模型列表 + +目前FastDeploy支持如下模型的部署 + +- [PPYOLOE系列模型](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/configs/ppyoloe) +- [PicoDet系列模型](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/configs/picodet) +- [PPYOLO系列模型(含v2)](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/configs/ppyolo) +- [YOLOv3系列模型](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/configs/yolov3) +- [YOLOX系列模型](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/configs/yolox) +- [FasterRCNN系列模型](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/configs/faster_rcnn) + +## 导出部署模型 + +在部署前,需要先将PaddleDetection导出成部署模型,导出步骤参考文档[导出模型](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.4/deploy/EXPORT_MODEL.md) + +注意:在导出模型时不要进行NMS的去除操作,正常导出即可。 + +## 下载预训练模型 + +为了方便开发者的测试,下面提供了PaddleDetection导出的各系列模型,开发者可直接下载使用。 + +其中精度指标来源于PaddleDetection中对各模型的介绍,详情可参考PaddleDetection中的说明。 + + +| 模型 | 参数大小 | 精度 | 备注 | |:---------------------------------------------------------------- |:----- |:----- | :------ | | [picodet_l_320_coco_lcnet](https://bj.bcebos.com/paddlehub/fastdeploy/picodet_l_320_coco_lcnet.tgz) |23MB | 42.6% | | [ppyoloe_crn_l_300e_coco](https://bj.bcebos.com/paddlehub/fastdeploy/ppyoloe_crn_l_300e_coco.tgz) |200MB | 51.4% | | [ppyolo_r50vd_dcn_1x_coco](https://bj.bcebos.com/paddlehub/fastdeploy/ppyolo_r50vd_dcn_1x_coco.tgz) | 180MB | 44.8% | 暂不支持TensorRT | | 
[ppyolov2_r101vd_dcn_365e_coco](https://bj.bcebos.com/paddlehub/fastdeploy/ppyolov2_r101vd_dcn_365e_coco.tgz) | 282MB | 49.7% | 暂不支持TensorRT | +| [yolov3_darknet53_270e_coco](https://bj.bcebos.com/paddlehub/fastdeploy/yolov3_darknet53_270e_coco.tgz) |237MB | 39.1% | | +| [yolox_s_300e_coco](https://bj.bcebos.com/paddlehub/fastdeploy/yolox_s_300e_coco.tgz) | 35MB | 40.4% | | +| [faster_rcnn_r50_vd_fpn_2x_coco](https://bj.bcebos.com/paddlehub/fastdeploy/faster_rcnn_r50_vd_fpn_2x_coco.tgz) | 160MB | 40.8%| 暂不支持TensorRT | + + +## 详细部署文档 + +- [Python部署](python) +- [C++部署](cpp) diff --git a/examples/vision/detection/paddledetection/cpp/CMakeLists.txt b/examples/vision/detection/paddledetection/cpp/CMakeLists.txt new file mode 100644 index 000000000..b570aa2db --- /dev/null +++ b/examples/vision/detection/paddledetection/cpp/CMakeLists.txt @@ -0,0 +1,28 @@ +PROJECT(infer_demo C CXX) +CMAKE_MINIMUM_REQUIRED (VERSION 3.12) + +# 指定下载解压后的fastdeploy库路径 +option(FASTDEPLOY_INSTALL_DIR "Path of downloaded fastdeploy sdk.") + +include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake) + +# 添加FastDeploy依赖头文件 +include_directories(${FASTDEPLOY_INCS}) + +add_executable(infer_faster_rcnn_demo ${PROJECT_SOURCE_DIR}/infer_faster_rcnn.cc) +target_link_libraries(infer_faster_rcnn_demo ${FASTDEPLOY_LIBS}) + +add_executable(infer_ppyoloe_demo ${PROJECT_SOURCE_DIR}/infer_ppyoloe.cc) +target_link_libraries(infer_ppyoloe_demo ${FASTDEPLOY_LIBS}) + +add_executable(infer_picodet_demo ${PROJECT_SOURCE_DIR}/infer_picodet.cc) +target_link_libraries(infer_picodet_demo ${FASTDEPLOY_LIBS}) + +add_executable(infer_yolox_demo ${PROJECT_SOURCE_DIR}/infer_yolox.cc) +target_link_libraries(infer_yolox_demo ${FASTDEPLOY_LIBS}) + +add_executable(infer_yolov3_demo ${PROJECT_SOURCE_DIR}/infer_yolov3.cc) +target_link_libraries(infer_yolov3_demo ${FASTDEPLOY_LIBS}) + +add_executable(infer_ppyolo_demo ${PROJECT_SOURCE_DIR}/infer_ppyolo.cc) +target_link_libraries(infer_ppyolo_demo ${FASTDEPLOY_LIBS}) diff --git a/examples/vision/detection/paddledetection/cpp/README.md b/examples/vision/detection/paddledetection/cpp/README.md new file mode 100644 index 000000000..bb0c6ed71 --- /dev/null +++ b/examples/vision/detection/paddledetection/cpp/README.md @@ -0,0 +1,75 @@ +# PaddleDetection C++部署示例 + +本目录下提供`infer_xxx.cc`快速完成PaddleDetection模型包括PPYOLOE/PicoDet/YOLOX/YOLOv3/PPYOLO/FasterRCNN在CPU/GPU,以及GPU上通过TensorRT加速部署的示例。 + +在部署前,需确认以下两个步骤 + +- 1. 软硬件环境满足要求,参考[FastDeploy环境要求](../../../../../docs/quick_start/requirements.md) +- 2. 根据开发环境,下载预编译部署库和samples代码,参考[FastDeploy预编译库](../../../../../docs/compile/prebuilt_libraries.md) + +以Linux上CPU推理为例,在本目录执行如下命令即可完成编译测试 + +``` +mkdir build +cd build +wget https://bj.bcebos.com/paddlehub/fastdeploy/libs/0.2.0/fastdeploy-linux-x64-gpu-0.2.0.tgz +tar xvf fastdeploy-linux-x64-gpu-0.2.0.tgz +cd fastdeploy-linux-x64-gpu-0.2.0/examples/vision/detection/paddledetection +mkdir build && cd build +cmake .. 
-DFASTDEPLOY_INSTALL_DIR=${PWD}/../../../../../../fastdeploy-linux-x64-gpu-0.2.0 +make -j + +# 下载PicoDet模型文件和测试图片 +wget https://bj.bcebos.com/paddlehub/fastdeploy/picodet_l_320_coco_lcnet.tgz +wget https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000087038.jpg +tar xvf picodet_l_320_coco_lcnet.tgz + + +# CPU推理 +./infer_picodet_demo ./picodet_l_320_coco_lcnet 000000087038.jpg 0 +# GPU推理 +./infer_picodet_demo ./picodet_l_320_coco_lcnet 000000087038.jpg 1 +# GPU上TensorRT推理 +./infer_picodet_demo ./picodet_l_320_coco_lcnet 000000087038.jpg 2 +``` + +## PaddleDetection C++接口 + +### 模型类 + +PaddleDetection目前支持6种模型系列,类名分别为`PPYOLOE`, `PicoDet`, `PaddleYOLOX`, `YOLOv3`, `PPYOLO`, `FasterRCNN`,所有类名的构造函数和预测函数在参数上完全一致,本文档以PPYOLOE为例讲解API +``` +fastdeploy::vision::detection::PPYOLOE( + const string& model_file, + const string& params_file, + const string& config_file, + const RuntimeOption& runtime_option = RuntimeOption(), + const Frontend& model_format = Frontend::PADDLE) +``` + +PaddleDetection PPYOLOE模型加载和初始化,其中model_file为导出的Paddle模型格式。 + +**参数** + +> * **model_file**(str): 模型文件路径 +> * **params_file**(str): 参数文件路径 +> * **config_file**(str): 配置文件路径,即PaddleDetection导出的部署yaml文件 +> * **runtime_option**(RuntimeOption): 后端推理配置,默认为None,即采用默认配置 +> * **model_format**(Frontend): 模型格式,默认为PADDLE格式 + +#### Predict函数 + +> ``` +> PPYOLOE::Predict(cv::Mat* im, DetectionResult* result) +> ``` +> +> 模型预测接口,输入图像直接输出检测结果。 +> +> **参数** +> +> > * **im**: 输入图像,注意需为HWC,BGR格式 +> > * **result**: 检测结果,包括检测框,各个框的置信度, DetectionResult说明参考[视觉模型预测结果](../../../../../docs/api/vision_results/) + +- [模型介绍](../../) +- [Python部署](../python) +- [视觉模型预测结果](../../../../../docs/api/vision_results/) diff --git a/examples/vision/detection/paddledetection/cpp/infer_faster_rcnn.cc b/examples/vision/detection/paddledetection/cpp/infer_faster_rcnn.cc new file mode 100644 index 000000000..7bd7fd91c --- /dev/null +++ b/examples/vision/detection/paddledetection/cpp/infer_faster_rcnn.cc @@ -0,0 +1,94 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fastdeploy/vision.h" + +#ifdef WIN32 +const char sep = '\\'; +#else +const char sep = '/'; +#endif + +void CpuInfer(const std::string& model_dir, const std::string& image_file) { + auto model_file = model_dir + sep + "model.pdmodel"; + auto params_file = model_dir + sep + "model.pdiparams"; + auto config_file = model_dir + sep + "infer_cfg.yml"; + auto model = fastdeploy::vision::detection::FasterRCNN( + model_file, params_file, config_file); + if (!model.Initialized()) { + std::cerr << "Failed to initialize." << std::endl; + return; + } + + auto im = cv::imread(image_file); + auto im_bak = im.clone(); + + fastdeploy::vision::DetectionResult res; + if (!model.Predict(&im, &res)) { + std::cerr << "Failed to predict."
<< std::endl; + return; + } + + auto vis_im = fastdeploy::vision::Visualize::VisDetection(im_bak, res, 0.5); + cv::imwrite("vis_result.jpg", vis_im); + std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl; +} + +void GpuInfer(const std::string& model_dir, const std::string& image_file) { + auto model_file = model_dir + sep + "model.pdmodel"; + auto params_file = model_dir + sep + "model.pdiparams"; + auto config_file = model_dir + sep + "infer_cfg.yml"; + + auto option = fastdeploy::RuntimeOption(); + option.UseGpu(); + auto model = fastdeploy::vision::detection::FasterRCNN( + model_file, params_file, config_file, option); + if (!model.Initialized()) { + std::cerr << "Failed to initialize." << std::endl; + return; + } + + auto im = cv::imread(image_file); + auto im_bak = im.clone(); + + fastdeploy::vision::DetectionResult res; + if (!model.Predict(&im, &res)) { + std::cerr << "Failed to predict." << std::endl; + return; + } + + auto vis_im = fastdeploy::vision::Visualize::VisDetection(im_bak, res, 0.5); + cv::imwrite("vis_result.jpg", vis_im); + std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl; +} + +int main(int argc, char* argv[]) { + if (argc < 4) { + std::cout + << "Usage: infer_demo path/to/model_dir path/to/image run_option, " + "e.g ./infer_model ./faster_rcnn_r50_vd_fpn_2x_coco ./test.jpeg 0" + << std::endl; + std::cout << "The data type of run_option is int, 0: run with cpu; 1: run " + "with gpu." + << std::endl; + return -1; + } + + if (std::atoi(argv[3]) == 0) { + CpuInfer(argv[1], argv[2]); + } else if (std::atoi(argv[3]) == 1) { + GpuInfer(argv[1], argv[2]); + } + return 0; +} diff --git a/examples/vision/detection/paddledetection/cpp/infer_picodet.cc b/examples/vision/detection/paddledetection/cpp/infer_picodet.cc new file mode 100644 index 000000000..19c2a6837 --- /dev/null +++ b/examples/vision/detection/paddledetection/cpp/infer_picodet.cc @@ -0,0 +1,127 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fastdeploy/vision.h" + +#ifdef WIN32 +const char sep = '\\'; +#else +const char sep = '/'; +#endif + +void CpuInfer(const std::string& model_dir, const std::string& image_file) { + auto model_file = model_dir + sep + "model.pdmodel"; + auto params_file = model_dir + sep + "model.pdiparams"; + auto config_file = model_dir + sep + "infer_cfg.yml"; + auto model = fastdeploy::vision::detection::PicoDet(model_file, params_file, + config_file); + if (!model.Initialized()) { + std::cerr << "Failed to initialize." << std::endl; + return; + } + + auto im = cv::imread(image_file); + auto im_bak = im.clone(); + + fastdeploy::vision::DetectionResult res; + if (!model.Predict(&im, &res)) { + std::cerr << "Failed to predict." 
<< std::endl; + return; + } + + auto vis_im = fastdeploy::vision::Visualize::VisDetection(im_bak, res, 0.5); + cv::imwrite("vis_result.jpg", vis_im); + std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl; +} + +void GpuInfer(const std::string& model_dir, const std::string& image_file) { + auto model_file = model_dir + sep + "model.pdmodel"; + auto params_file = model_dir + sep + "model.pdiparams"; + auto config_file = model_dir + sep + "infer_cfg.yml"; + + auto option = fastdeploy::RuntimeOption(); + option.UseGpu(); + auto model = fastdeploy::vision::detection::PicoDet(model_file, params_file, + config_file, option); + if (!model.Initialized()) { + std::cerr << "Failed to initialize." << std::endl; + return; + } + + auto im = cv::imread(image_file); + auto im_bak = im.clone(); + + fastdeploy::vision::DetectionResult res; + if (!model.Predict(&im, &res)) { + std::cerr << "Failed to predict." << std::endl; + return; + } + + auto vis_im = fastdeploy::vision::Visualize::VisDetection(im_bak, res, 0.5); + cv::imwrite("vis_result.jpg", vis_im); + std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl; +} + +void TrtInfer(const std::string& model_dir, const std::string& image_file) { + auto model_file = model_dir + sep + "model.pdmodel"; + auto params_file = model_dir + sep + "model.pdiparams"; + auto config_file = model_dir + sep + "infer_cfg.yml"; + + auto option = fastdeploy::RuntimeOption(); + option.UseGpu(); + option.UseTrtBackend(); + option.SetTrtInputShape("image", {1, 3, 320, 320}); + option.SetTrtInputShape("scale_factor", {1, 2}); + auto model = fastdeploy::vision::detection::PicoDet(model_file, params_file, + config_file, option); + if (!model.Initialized()) { + std::cerr << "Failed to initialize." << std::endl; + return; + } + + auto im = cv::imread(image_file); + auto im_bak = im.clone(); + + fastdeploy::vision::DetectionResult res; + if (!model.Predict(&im, &res)) { + std::cerr << "Failed to predict." << std::endl; + return; + } + + auto vis_im = fastdeploy::vision::Visualize::VisDetection(im_bak, res, 0.5); + cv::imwrite("vis_result.jpg", vis_im); + std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl; +} + +int main(int argc, char* argv[]) { + if (argc < 4) { + std::cout + << "Usage: infer_demo path/to/model_dir path/to/image run_option, " + "e.g ./infer_model ./picodet_model_dir ./test.jpeg 0" + << std::endl; + std::cout << "The data type of run_option is int, 0: run with cpu; 1: run " + "with gpu; 2: run with gpu and use tensorrt backend." + << std::endl; + return -1; + } + + if (std::atoi(argv[3]) == 0) { + CpuInfer(argv[1], argv[2]); + } else if (std::atoi(argv[3]) == 1) { + GpuInfer(argv[1], argv[2]); + } else if (std::atoi(argv[3]) == 2) { + TrtInfer(argv[1], argv[2]); + } + return 0; +} diff --git a/examples/vision/detection/paddledetection/cpp/infer_ppyolo.cc b/examples/vision/detection/paddledetection/cpp/infer_ppyolo.cc new file mode 100644 index 000000000..a111e70f5 --- /dev/null +++ b/examples/vision/detection/paddledetection/cpp/infer_ppyolo.cc @@ -0,0 +1,94 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fastdeploy/vision.h" + +#ifdef WIN32 +const char sep = '\\'; +#else +const char sep = '/'; +#endif + +void CpuInfer(const std::string& model_dir, const std::string& image_file) { + auto model_file = model_dir + sep + "model.pdmodel"; + auto params_file = model_dir + sep + "model.pdiparams"; + auto config_file = model_dir + sep + "infer_cfg.yml"; + auto model = fastdeploy::vision::detection::PPYOLO(model_file, params_file, + config_file); + if (!model.Initialized()) { + std::cerr << "Failed to initialize." << std::endl; + return; + } + + auto im = cv::imread(image_file); + auto im_bak = im.clone(); + + fastdeploy::vision::DetectionResult res; + if (!model.Predict(&im, &res)) { + std::cerr << "Failed to predict." << std::endl; + return; + } + + auto vis_im = fastdeploy::vision::Visualize::VisDetection(im_bak, res, 0.5); + cv::imwrite("vis_result.jpg", vis_im); + std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl; +} + +void GpuInfer(const std::string& model_dir, const std::string& image_file) { + auto model_file = model_dir + sep + "model.pdmodel"; + auto params_file = model_dir + sep + "model.pdiparams"; + auto config_file = model_dir + sep + "infer_cfg.yml"; + + auto option = fastdeploy::RuntimeOption(); + option.UseGpu(); + auto model = fastdeploy::vision::detection::PPYOLO(model_file, params_file, + config_file, option); + if (!model.Initialized()) { + std::cerr << "Failed to initialize." << std::endl; + return; + } + + auto im = cv::imread(image_file); + auto im_bak = im.clone(); + + fastdeploy::vision::DetectionResult res; + if (!model.Predict(&im, &res)) { + std::cerr << "Failed to predict." << std::endl; + return; + } + + auto vis_im = fastdeploy::vision::Visualize::VisDetection(im_bak, res, 0.5); + cv::imwrite("vis_result.jpg", vis_im); + std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl; +} + +int main(int argc, char* argv[]) { + if (argc < 4) { + std::cout + << "Usage: infer_demo path/to/model_dir path/to/image run_option, " + "e.g ./infer_model ./ppyolo_dirname ./test.jpeg 0" + << std::endl; + std::cout << "The data type of run_option is int, 0: run with cpu; 1: run " + "with gpu." + << std::endl; + return -1; + } + + if (std::atoi(argv[3]) == 0) { + CpuInfer(argv[1], argv[2]); + } else if (std::atoi(argv[3]) == 1) { + GpuInfer(argv[1], argv[2]); + } + return 0; +} diff --git a/examples/vision/detection/paddledetection/cpp/infer_ppyoloe.cc b/examples/vision/detection/paddledetection/cpp/infer_ppyoloe.cc new file mode 100644 index 000000000..ec01d3914 --- /dev/null +++ b/examples/vision/detection/paddledetection/cpp/infer_ppyoloe.cc @@ -0,0 +1,127 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fastdeploy/vision.h" + +#ifdef WIN32 +const char sep = '\\'; +#else +const char sep = '/'; +#endif + +void CpuInfer(const std::string& model_dir, const std::string& image_file) { + auto model_file = model_dir + sep + "model.pdmodel"; + auto params_file = model_dir + sep + "model.pdiparams"; + auto config_file = model_dir + sep + "infer_cfg.yml"; + auto model = fastdeploy::vision::detection::PPYOLOE(model_file, params_file, + config_file); + if (!model.Initialized()) { + std::cerr << "Failed to initialize." << std::endl; + return; + } + + auto im = cv::imread(image_file); + auto im_bak = im.clone(); + + fastdeploy::vision::DetectionResult res; + if (!model.Predict(&im, &res)) { + std::cerr << "Failed to predict." << std::endl; + return; + } + + auto vis_im = fastdeploy::vision::Visualize::VisDetection(im_bak, res, 0.5); + cv::imwrite("vis_result.jpg", vis_im); + std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl; +} + +void GpuInfer(const std::string& model_dir, const std::string& image_file) { + auto model_file = model_dir + sep + "model.pdmodel"; + auto params_file = model_dir + sep + "model.pdiparams"; + auto config_file = model_dir + sep + "infer_cfg.yml"; + + auto option = fastdeploy::RuntimeOption(); + option.UseGpu(); + auto model = fastdeploy::vision::detection::PPYOLOE(model_file, params_file, + config_file, option); + if (!model.Initialized()) { + std::cerr << "Failed to initialize." << std::endl; + return; + } + + auto im = cv::imread(image_file); + auto im_bak = im.clone(); + + fastdeploy::vision::DetectionResult res; + if (!model.Predict(&im, &res)) { + std::cerr << "Failed to predict." << std::endl; + return; + } + + auto vis_im = fastdeploy::vision::Visualize::VisDetection(im_bak, res, 0.5); + cv::imwrite("vis_result.jpg", vis_im); + std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl; +} + +void TrtInfer(const std::string& model_dir, const std::string& image_file) { + auto model_file = model_dir + sep + "model.pdmodel"; + auto params_file = model_dir + sep + "model.pdiparams"; + auto config_file = model_dir + sep + "infer_cfg.yml"; + + auto option = fastdeploy::RuntimeOption(); + option.UseGpu(); + option.UseTrtBackend(); + option.SetTrtInputShape("image", {1, 3, 640, 640}); + option.SetTrtInputShape("scale_factor", {1, 2}); + auto model = fastdeploy::vision::detection::PPYOLOE(model_file, params_file, + config_file, option); + if (!model.Initialized()) { + std::cerr << "Failed to initialize." << std::endl; + return; + } + + auto im = cv::imread(image_file); + auto im_bak = im.clone(); + + fastdeploy::vision::DetectionResult res; + if (!model.Predict(&im, &res)) { + std::cerr << "Failed to predict." 
<< std::endl; + return; + } + + auto vis_im = fastdeploy::vision::Visualize::VisDetection(im_bak, res, 0.5); + cv::imwrite("vis_result.jpg", vis_im); + std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl; +} + +int main(int argc, char* argv[]) { + if (argc < 4) { + std::cout + << "Usage: infer_demo path/to/model_dir path/to/image run_option, " + "e.g ./infer_model ./ppyoloe_model_dir ./test.jpeg 0" + << std::endl; + std::cout << "The data type of run_option is int, 0: run with cpu; 1: run " + "with gpu; 2: run with gpu and use tensorrt backend." + << std::endl; + return -1; + } + + if (std::atoi(argv[3]) == 0) { + CpuInfer(argv[1], argv[2]); + } else if (std::atoi(argv[3]) == 1) { + GpuInfer(argv[1], argv[2]); + } else if (std::atoi(argv[3]) == 2) { + TrtInfer(argv[1], argv[2]); + } + return 0; +} diff --git a/examples/vision/detection/paddledetection/cpp/infer_yolov3.cc b/examples/vision/detection/paddledetection/cpp/infer_yolov3.cc new file mode 100644 index 000000000..54571b8d2 --- /dev/null +++ b/examples/vision/detection/paddledetection/cpp/infer_yolov3.cc @@ -0,0 +1,94 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fastdeploy/vision.h" + +#ifdef WIN32 +const char sep = '\\'; +#else +const char sep = '/'; +#endif + +void CpuInfer(const std::string& model_dir, const std::string& image_file) { + auto model_file = model_dir + sep + "model.pdmodel"; + auto params_file = model_dir + sep + "model.pdiparams"; + auto config_file = model_dir + sep + "infer_cfg.yml"; + auto model = fastdeploy::vision::detection::YOLOv3(model_file, params_file, + config_file); + if (!model.Initialized()) { + std::cerr << "Failed to initialize." << std::endl; + return; + } + + auto im = cv::imread(image_file); + auto im_bak = im.clone(); + + fastdeploy::vision::DetectionResult res; + if (!model.Predict(&im, &res)) { + std::cerr << "Failed to predict." << std::endl; + return; + } + + auto vis_im = fastdeploy::vision::Visualize::VisDetection(im_bak, res, 0.5); + cv::imwrite("vis_result.jpg", vis_im); + std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl; +} + +void GpuInfer(const std::string& model_dir, const std::string& image_file) { + auto model_file = model_dir + sep + "model.pdmodel"; + auto params_file = model_dir + sep + "model.pdiparams"; + auto config_file = model_dir + sep + "infer_cfg.yml"; + + auto option = fastdeploy::RuntimeOption(); + option.UseGpu(); + auto model = fastdeploy::vision::detection::YOLOv3(model_file, params_file, + config_file, option); + if (!model.Initialized()) { + std::cerr << "Failed to initialize." << std::endl; + return; + } + + auto im = cv::imread(image_file); + auto im_bak = im.clone(); + + fastdeploy::vision::DetectionResult res; + if (!model.Predict(&im, &res)) { + std::cerr << "Failed to predict." 
<< std::endl; + return; + } + + auto vis_im = fastdeploy::vision::Visualize::VisDetection(im_bak, res, 0.5); + cv::imwrite("vis_result.jpg", vis_im); + std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl; +} + +int main(int argc, char* argv[]) { + if (argc < 4) { + std::cout + << "Usage: infer_demo path/to/model_dir path/to/image run_option, " + "e.g ./infer_model ./ppyolo_dirname ./test.jpeg 0" + << std::endl; + std::cout << "The data type of run_option is int, 0: run with cpu; 1: run " + "with gpu." + << std::endl; + return -1; + } + + if (std::atoi(argv[3]) == 0) { + CpuInfer(argv[1], argv[2]); + } else if (std::atoi(argv[3]) == 1) { + GpuInfer(argv[1], argv[2]); + } + return 0; +} diff --git a/examples/vision/detection/paddledetection/cpp/infer_yolox.cc b/examples/vision/detection/paddledetection/cpp/infer_yolox.cc new file mode 100644 index 000000000..8eb9b4224 --- /dev/null +++ b/examples/vision/detection/paddledetection/cpp/infer_yolox.cc @@ -0,0 +1,127 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fastdeploy/vision.h" + +#ifdef WIN32 +const char sep = '\\'; +#else +const char sep = '/'; +#endif + +void CpuInfer(const std::string& model_dir, const std::string& image_file) { + auto model_file = model_dir + sep + "model.pdmodel"; + auto params_file = model_dir + sep + "model.pdiparams"; + auto config_file = model_dir + sep + "infer_cfg.yml"; + auto model = fastdeploy::vision::detection::PaddleYOLOX( + model_file, params_file, config_file); + if (!model.Initialized()) { + std::cerr << "Failed to initialize." << std::endl; + return; + } + + auto im = cv::imread(image_file); + auto im_bak = im.clone(); + + fastdeploy::vision::DetectionResult res; + if (!model.Predict(&im, &res)) { + std::cerr << "Failed to predict." << std::endl; + return; + } + + auto vis_im = fastdeploy::vision::Visualize::VisDetection(im_bak, res, 0.5); + cv::imwrite("vis_result.jpg", vis_im); + std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl; +} + +void GpuInfer(const std::string& model_dir, const std::string& image_file) { + auto model_file = model_dir + sep + "model.pdmodel"; + auto params_file = model_dir + sep + "model.pdiparams"; + auto config_file = model_dir + sep + "infer_cfg.yml"; + + auto option = fastdeploy::RuntimeOption(); + option.UseGpu(); + auto model = fastdeploy::vision::detection::PaddleYOLOX( + model_file, params_file, config_file, option); + if (!model.Initialized()) { + std::cerr << "Failed to initialize." << std::endl; + return; + } + + auto im = cv::imread(image_file); + auto im_bak = im.clone(); + + fastdeploy::vision::DetectionResult res; + if (!model.Predict(&im, &res)) { + std::cerr << "Failed to predict." 
<< std::endl; + return; + } + + auto vis_im = fastdeploy::vision::Visualize::VisDetection(im_bak, res, 0.5); + cv::imwrite("vis_result.jpg", vis_im); + std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl; +} + +void TrtInfer(const std::string& model_dir, const std::string& image_file) { + auto model_file = model_dir + sep + "model.pdmodel"; + auto params_file = model_dir + sep + "model.pdiparams"; + auto config_file = model_dir + sep + "infer_cfg.yml"; + + auto option = fastdeploy::RuntimeOption(); + option.UseGpu(); + option.UseTrtBackend(); + option.SetTrtInputShape("image", {1, 3, 640, 640}); + option.SetTrtInputShape("scale_factor", {1, 2}); + auto model = fastdeploy::vision::detection::PaddleYOLOX( + model_file, params_file, config_file, option); + if (!model.Initialized()) { + std::cerr << "Failed to initialize." << std::endl; + return; + } + + auto im = cv::imread(image_file); + auto im_bak = im.clone(); + + fastdeploy::vision::DetectionResult res; + if (!model.Predict(&im, &res)) { + std::cerr << "Failed to predict." << std::endl; + return; + } + + auto vis_im = fastdeploy::vision::Visualize::VisDetection(im_bak, res, 0.5); + cv::imwrite("vis_result.jpg", vis_im); + std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl; +} + +int main(int argc, char* argv[]) { + if (argc < 4) { + std::cout + << "Usage: infer_demo path/to/model_dir path/to/image run_option, " + "e.g ./infer_model ./paddle_yolox_dirname ./test.jpeg 0" + << std::endl; + std::cout << "The data type of run_option is int, 0: run with cpu; 1: run " + "with gpu; 2: run with gpu by tensorrt." + << std::endl; + return -1; + } + + if (std::atoi(argv[3]) == 0) { + CpuInfer(argv[1], argv[2]); + } else if (std::atoi(argv[3]) == 1) { + GpuInfer(argv[1], argv[2]); + } else if (std::atoi(argv[3]) == 2) { + TrtInfer(argv[1], argv[2]); + } + return 0; +} diff --git a/examples/vision/detection/paddledetection/python/README.md b/examples/vision/detection/paddledetection/python/README.md new file mode 100644 index 000000000..3863481ab --- /dev/null +++ b/examples/vision/detection/paddledetection/python/README.md @@ -0,0 +1,72 @@ +# PaddleDetection Python部署示例 + +在部署前,需确认以下两个步骤 + +- 1. 软硬件环境满足要求,参考[FastDeploy环境要求](../../../../../docs/quick_start/requirements.md) +- 2. 
FastDeploy Python whl包安装,参考[FastDeploy Python安装](../../../../../docs/quick_start/install.md) + +本目录下提供`infer_xxx.py`快速完成PPYOLOE/PicoDet等模型在CPU/GPU,以及GPU上通过TensorRT加速部署的示例。执行如下脚本即可完成 + +``` +#下载PPYOLOE模型文件和测试图片 +wget https://bj.bcebos.com/paddlehub/fastdeploy/ppyoloe_crn_l_300e_coco.tgz +wget https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg +tar xvf ppyoloe_crn_l_300e_coco.tgz + +#下载部署示例代码 +git clone https://github.com/PaddlePaddle/FastDeploy.git +cd FastDeploy/examples/vision/detection/paddledetection/python/ + +# CPU推理 +python infer_ppyoloe.py --model_dir ppyoloe_crn_l_300e_coco --image 000000014439.jpg --device cpu +# GPU推理 +python infer_ppyoloe.py --model_dir ppyoloe_crn_l_300e_coco --image 000000014439.jpg --device gpu +# GPU上使用TensorRT推理 (注意:TensorRT推理第一次运行,有序列化模型的操作,有一定耗时,需要耐心等待) +python infer_ppyoloe.py --model_dir ppyoloe_crn_l_300e_coco --image 000000014439.jpg --device gpu --use_trt True +``` + +运行完成可视化结果如下图所示 + +## PaddleDetection Python接口 + +``` +fastdeploy.vision.detection.PPYOLOE(model_file, params_file, config_file, runtime_option=None, model_format=Frontend.PADDLE) +fastdeploy.vision.detection.PicoDet(model_file, params_file, config_file, runtime_option=None, model_format=Frontend.PADDLE) +fastdeploy.vision.detection.PaddleYOLOX(model_file, params_file, config_file, runtime_option=None, model_format=Frontend.PADDLE) +fastdeploy.vision.detection.YOLOv3(model_file, params_file, config_file, runtime_option=None, model_format=Frontend.PADDLE) +fastdeploy.vision.detection.PPYOLO(model_file, params_file, config_file, runtime_option=None, model_format=Frontend.PADDLE) +fastdeploy.vision.detection.FasterRCNN(model_file, params_file, config_file, runtime_option=None, model_format=Frontend.PADDLE) +``` + +PaddleDetection模型加载和初始化,其中model_file, params_file为导出的Paddle部署模型格式, config_file为PaddleDetection同时导出的部署配置yaml文件 + +**参数** + +> * **model_file**(str): 模型文件路径 +> * **params_file**(str): 参数文件路径 +> * **config_file**(str): 推理配置yaml文件路径 +> * **runtime_option**(RuntimeOption): 后端推理配置,默认为None,即采用默认配置 +> * **model_format**(Frontend): 模型格式,默认为Paddle + +### predict函数 + +PaddleDetection中各个模型,包括PPYOLOE/PicoDet/PaddleYOLOX/YOLOv3/PPYOLO/FasterRCNN,均提供如下同样的成员函数用于进行图像的检测 +> ``` +> PPYOLOE.predict(image_data, conf_threshold=0.25, nms_iou_threshold=0.5) +> ``` +> +> 模型预测接口,输入图像直接输出检测结果。 +> +> **参数** +> +> > * **image_data**(np.ndarray): 输入数据,注意需为HWC,BGR格式 + +> **返回** +> +> > 返回`fastdeploy.vision.DetectionResult`结构体,结构体说明参考文档[视觉模型预测结果](../../../../../docs/api/vision_results/) + +## 其它文档 + +- [PaddleDetection 模型介绍](..)
+- [PaddleDetection C++部署](../cpp) +- [模型预测结果说明](../../../../../docs/api/vision_results/) diff --git a/examples/vision/detection/paddledetection/python/infer_faster_rcnn.py b/examples/vision/detection/paddledetection/python/infer_faster_rcnn.py new file mode 100644 index 000000000..1100aa8a6 --- /dev/null +++ b/examples/vision/detection/paddledetection/python/infer_faster_rcnn.py @@ -0,0 +1,61 @@ +import fastdeploy as fd +import cv2 +import os + + +def parse_arguments(): + import argparse + import ast + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_dir", + required=True, + help="Path of PaddleDetection model directory") + parser.add_argument( + "--image", required=True, help="Path of test image file.") + parser.add_argument( + "--device", + type=str, + default='cpu', + help="Type of inference device, support 'cpu' or 'gpu'.") + parser.add_argument( + "--use_trt", + type=ast.literal_eval, + default=False, + help="Wether to use tensorrt.") + return parser.parse_args() + + +def build_option(args): + option = fd.RuntimeOption() + + if args.device.lower() == "gpu": + option.use_gpu() + + if args.use_trt: + option.use_trt_backend() + option.set_trt_input_shape("image", [1, 3, 640, 640]) + option.set_trt_input_shape("scale_factor", [1, 2]) + return option + + +args = parse_arguments() + +model_file = os.path.join(args.model_dir, "model.pdmodel") +params_file = os.path.join(args.model_dir, "model.pdiparams") +config_file = os.path.join(args.model_dir, "infer_cfg.yml") + +# 配置runtime,加载模型 +runtime_option = build_option(args) +model = fd.vision.detection.FasterRCNN( + model_file, params_file, config_file, runtime_option=runtime_option) + +# 预测图片检测结果 +im = cv2.imread(args.image) +result = model.predict(im) +print(result) + +# 预测结果可视化 +vis_im = fd.vision.vis_detection(im, result, score_threshold=0.5) +cv2.imwrite("visualized_result.jpg", vis_im) +print("Visualized result save in ./visualized_result.jpg") diff --git a/examples/vision/detection/paddledetection/python/infer_picodet.py b/examples/vision/detection/paddledetection/python/infer_picodet.py new file mode 100644 index 000000000..06bfad03c --- /dev/null +++ b/examples/vision/detection/paddledetection/python/infer_picodet.py @@ -0,0 +1,61 @@ +import fastdeploy as fd +import cv2 +import os + + +def parse_arguments(): + import argparse + import ast + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_dir", + required=True, + help="Path of PaddleDetection model directory") + parser.add_argument( + "--image", required=True, help="Path of test image file.") + parser.add_argument( + "--device", + type=str, + default='cpu', + help="Type of inference device, support 'cpu' or 'gpu'.") + parser.add_argument( + "--use_trt", + type=ast.literal_eval, + default=False, + help="Wether to use tensorrt.") + return parser.parse_args() + + +def build_option(args): + option = fd.RuntimeOption() + + if args.device.lower() == "gpu": + option.use_gpu() + + if args.use_trt: + option.use_trt_backend() + option.set_trt_input_shape("image", [1, 3, 320, 320]) + option.set_trt_input_shape("scale_factor", [1, 2]) + return option + + +args = parse_arguments() + +model_file = os.path.join(args.model_dir, "model.pdmodel") +params_file = os.path.join(args.model_dir, "model.pdiparams") +config_file = os.path.join(args.model_dir, "infer_cfg.yml") + +# 配置runtime,加载模型 +runtime_option = build_option(args) +model = fd.vision.detection.PicoDet( + model_file, params_file, config_file, runtime_option=runtime_option) + +# 预测图片检测结果 +im = 
cv2.imread(args.image) +result = model.predict(im) +print(result) + +# 预测结果可视化 +vis_im = fd.vision.vis_detection(im, result, score_threshold=0.5) +cv2.imwrite("visualized_result.jpg", vis_im) +print("Visualized result save in ./visualized_result.jpg") diff --git a/examples/vision/detection/paddledetection/python/infer_ppyolo.py b/examples/vision/detection/paddledetection/python/infer_ppyolo.py new file mode 100644 index 000000000..029f3dc21 --- /dev/null +++ b/examples/vision/detection/paddledetection/python/infer_ppyolo.py @@ -0,0 +1,62 @@ +import fastdeploy as fd +import cv2 +import os + + +def parse_arguments(): + import argparse + import ast + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_dir", + required=True, + help="Path of PaddleDetection model directory") + parser.add_argument( + "--image", required=True, help="Path of test image file.") + parser.add_argument( + "--device", + type=str, + default='cpu', + help="Type of inference device, support 'cpu' or 'gpu'.") + parser.add_argument( + "--use_trt", + type=ast.literal_eval, + default=False, + help="Wether to use tensorrt.") + return parser.parse_args() + + +def build_option(args): + option = fd.RuntimeOption() + + if args.device.lower() == "gpu": + option.use_gpu() + + if args.use_trt: + option.use_trt_backend() + option.set_trt_input_shape("image", [1, 3, 640, 640]) + option.set_trt_input_shape("scale_factor", [1, 2]) + return option + + +args = parse_arguments() + +model_file = os.path.join(args.model_dir, "model.pdmodel") +params_file = os.path.join(args.model_dir, "model.pdiparams") +config_file = os.path.join(args.model_dir, "infer_cfg.yml") + +# 配置runtime,加载模型 +runtime_option = build_option(args) +model = fd.vision.detection.PPYOLO( + model_file, params_file, config_file, runtime_option=runtime_option) + +# 预测图片检测结果 +im = cv2.imread(args.image) +result = model.predict(im) +print(result) + +# 预测结果可视化 +vis_im = fd.vision.vis_detection( + im, result, score_threshold=0.5) +cv2.imwrite("visualized_result.jpg", vis_im) +print("Visualized result save in ./visualized_result.jpg") diff --git a/examples/vision/detection/paddledetection/python/infer_ppyoloe.py b/examples/vision/detection/paddledetection/python/infer_ppyoloe.py new file mode 100644 index 000000000..ae533a509 --- /dev/null +++ b/examples/vision/detection/paddledetection/python/infer_ppyoloe.py @@ -0,0 +1,61 @@ +import fastdeploy as fd +import cv2 +import os + + +def parse_arguments(): + import argparse + import ast + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_dir", + required=True, + help="Path of PaddleDetection model directory") + parser.add_argument( + "--image", required=True, help="Path of test image file.") + parser.add_argument( + "--device", + type=str, + default='cpu', + help="Type of inference device, support 'cpu' or 'gpu'.") + parser.add_argument( + "--use_trt", + type=ast.literal_eval, + default=False, + help="Wether to use tensorrt.") + return parser.parse_args() + + +def build_option(args): + option = fd.RuntimeOption() + + if args.device.lower() == "gpu": + option.use_gpu() + + if args.use_trt: + option.use_trt_backend() + option.set_trt_input_shape("image", [1, 3, 640, 640]) + option.set_trt_input_shape("scale_factor", [1, 2]) + return option + + +args = parse_arguments() + +model_file = os.path.join(args.model_dir, "model.pdmodel") +params_file = os.path.join(args.model_dir, "model.pdiparams") +config_file = os.path.join(args.model_dir, "infer_cfg.yml") + +# 配置runtime,加载模型
+runtime_option = build_option(args) +model = fd.vision.detection.PPYOLOE( + model_file, params_file, config_file, runtime_option=runtime_option) + +# 预测图片检测结果 +im = cv2.imread(args.image) +result = model.predict(im) +print(result) + +# 预测结果可视化 +vis_im = fd.vision.vis_detection(im, result, score_threshold=0.5) +cv2.imwrite("visualized_result.jpg", vis_im) +print("Visualized result save in ./visualized_result.jpg") diff --git a/examples/vision/detection/paddledetection/python/infer_yolov3.py b/examples/vision/detection/paddledetection/python/infer_yolov3.py new file mode 100644 index 000000000..7ea372ff2 --- /dev/null +++ b/examples/vision/detection/paddledetection/python/infer_yolov3.py @@ -0,0 +1,62 @@ +import fastdeploy as fd +import cv2 +import os + + +def parse_arguments(): + import argparse + import ast + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_dir", + required=True, + help="Path of PaddleDetection model directory") + parser.add_argument( + "--image", required=True, help="Path of test image file.") + parser.add_argument( + "--device", + type=str, + default='cpu', + help="Type of inference device, support 'cpu' or 'gpu'.") + parser.add_argument( + "--use_trt", + type=ast.literal_eval, + default=False, + help="Wether to use tensorrt.") + return parser.parse_args() + + +def build_option(args): + option = fd.RuntimeOption() + + if args.device.lower() == "gpu": + option.use_gpu() + + if args.use_trt: + option.use_trt_backend() + option.set_trt_input_shape("image", [1, 3, 608, 608]) + option.set_trt_input_shape("im_shape", [1, 2]) + option.set_trt_input_shape("scale_factor", [1, 2]) + return option + + +args = parse_arguments() + +model_file = os.path.join(args.model_dir, "model.pdmodel") +params_file = os.path.join(args.model_dir, "model.pdiparams") +config_file = os.path.join(args.model_dir, "infer_cfg.yml") + +# 配置runtime,加载模型 +runtime_option = build_option(args) +model = fd.vision.detection.YOLOv3( + model_file, params_file, config_file, runtime_option=runtime_option) + +# 预测图片检测结果 +im = cv2.imread(args.image) +result = model.predict(im) +print(result) + +# 预测结果可视化 +vis_im = fd.vision.vis_detection(im, result, score_threshold=0.5) +cv2.imwrite("visualized_result.jpg", vis_im) +print("Visualized result save in ./visualized_result.jpg") diff --git a/examples/vision/detection/paddledetection/python/infer_yolox.py b/examples/vision/detection/paddledetection/python/infer_yolox.py new file mode 100644 index 000000000..f65b1d8b1 --- /dev/null +++ b/examples/vision/detection/paddledetection/python/infer_yolox.py @@ -0,0 +1,61 @@ +import fastdeploy as fd +import cv2 +import os + + +def parse_arguments(): + import argparse + import ast + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_dir", + required=True, + help="Path of PaddleDetection model directory") + parser.add_argument( + "--image", required=True, help="Path of test image file.") + parser.add_argument( + "--device", + type=str, + default='cpu', + help="Type of inference device, support 'cpu' or 'gpu'.") + parser.add_argument( + "--use_trt", + type=ast.literal_eval, + default=False, + help="Wether to use tensorrt.") + return parser.parse_args() + + +def build_option(args): + option = fd.RuntimeOption() + + if args.device.lower() == "gpu": + option.use_gpu() + + if args.use_trt: + option.use_trt_backend() + option.set_trt_input_shape("image", [1, 3, 640, 640]) + option.set_trt_input_shape("scale_factor", [1, 2]) + return option + + +args = parse_arguments() + +model_file = 
os.path.join(args.model_dir, "model.pdmodel") +params_file = os.path.join(args.model_dir, "model.pdiparams") +config_file = os.path.join(args.model_dir, "infer_cfg.yml") + +# 配置runtime,加载模型 +runtime_option = build_option(args) +model = fd.vision.detection.PaddleYOLOX( + model_file, params_file, config_file, runtime_option=runtime_option) + +# 预测图片检测结果 +im = cv2.imread(args.image) +result = model.predict(im) +print(result) + +# 预测结果可视化 +vis_im = fd.vision.vis_detection(im, result, score_threshold=0.5) +cv2.imwrite("visualized_result.jpg", vis_im) +print("Visualized result save in ./visualized_result.jpg") diff --git a/fastdeploy/runtime.py b/fastdeploy/runtime.py index 38c895498..43936533d 100644 --- a/fastdeploy/runtime.py +++ b/fastdeploy/runtime.py @@ -81,6 +81,12 @@ class RuntimeOption: def disable_paddle_mkldnn(self): return self._option.disable_paddle_mkldnn() + def enable_paddle_log_info(self): + return self._option.enable_paddle_log_info() + + def disable_paddle_log_info(self): + return self._option.disable_paddle_log_info() + def set_paddle_mkldnn_cache_size(self, cache_size): return self._option.set_paddle_mkldnn_cache_size(cache_size) diff --git a/fastdeploy/vision/__init__.py b/fastdeploy/vision/__init__.py index e99c019c3..97c81eb50 100644 --- a/fastdeploy/vision/__init__.py +++ b/fastdeploy/vision/__init__.py @@ -14,11 +14,12 @@ from __future__ import absolute_import from . import detection +from . import classification + from . import matting from . import facedet from . import faceid -from . import ppcls from . import ppseg from . import evaluation from .visualize import * diff --git a/fastdeploy/vision/visualize/__init__.py b/fastdeploy/vision/visualize/__init__.py index b2b8e90ad..faa54f824 100644 --- a/fastdeploy/vision/visualize/__init__.py +++ b/fastdeploy/vision/visualize/__init__.py @@ -17,9 +17,13 @@ import logging from ... import c_lib_wrap as C -def vis_detection(im_data, det_result, line_size=1, font_size=0.5): - return C.vision.Visualize.vis_detection(im_data, det_result, line_size, - font_size) +def vis_detection(im_data, + det_result, + score_threshold=0.0, + line_size=1, + font_size=0.5): + return C.vision.Visualize.vis_detection( + im_data, det_result, score_threshold, line_size, font_size) def vis_face_detection(im_data, face_det_result, line_size=1, font_size=0.5):
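
Taken together, the changes above add a logging switch for the Paddle Inference backend (`enable_paddle_log_info` / `disable_paddle_log_info`, off by default) and a `score_threshold` filter to the detection visualization API. Below is a minimal Python sketch of how the two new options are expected to be used together; the model directory and image path are placeholders, and a PaddleDetection PPYOLOE model exported as described in the README above is assumed.

```
import cv2
import fastdeploy as fd

# Placeholder paths: substitute your own exported model directory and test image.
model_dir = "ppyoloe_crn_l_300e_coco"

option = fd.RuntimeOption()
option.use_gpu()
# New in this patch: glog output from the Paddle Inference backend is silenced
# by default; call enable_paddle_log_info() only when debugging the backend.
option.disable_paddle_log_info()

model = fd.vision.detection.PPYOLOE(
    model_dir + "/model.pdmodel",
    model_dir + "/model.pdiparams",
    model_dir + "/infer_cfg.yml",
    runtime_option=option)

im = cv2.imread("test.jpg")
result = model.predict(im)

# New in this patch: vis_detection takes score_threshold ahead of
# line_size/font_size, and boxes scoring below it are not drawn.
vis_im = fd.vision.vis_detection(
    im, result, score_threshold=0.5, line_size=1, font_size=0.5)
cv2.imwrite("vis_result.jpg", vis_im)
```

The same behavior is reachable from C++ through `RuntimeOption::DisablePaddleLogInfo()` and the extra `score_threshold` parameter of `Visualize::VisDetection`, as the `infer_*.cc` examples above show.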