Add PaddleDetetion/PPYOLOE model support (#22)

* add ppdet/ppyoloe

* Add demo code and documents
This commit is contained in:
Jason
2022-07-19 13:40:16 +08:00
committed by GitHub
parent f8c3906c51
commit 6ae7a16c36
22 changed files with 645 additions and 16 deletions

View File

@@ -1,4 +1,4 @@
function(add_fastdeploy_executable field url model)
function(add_fastdeploy_executable field url model)
# temp target name/file var in function scope
set(TEMP_TARGET_FILE ${PROJECT_SOURCE_DIR}/examples/${field}/${url}_${model}.cc)
set(TEMP_TARGET_NAME ${field}_${url}_${model})
@@ -7,7 +7,7 @@ function(add_fastdeploy_executable field url model)
target_link_libraries(${TEMP_TARGET_NAME} PUBLIC fastdeploy)
message(STATUS "Found source file: [${field}/${url}_${model}.cc], ADD!!! fastdeploy executable: [${TEMP_TARGET_NAME}] !")
else ()
message(WARNING "Can not found source file: [${field}/${url}_${model}.cc], SKIP!!! fastdeploy executable: [${TEMP_TARGET_NAME}] !")
message(WARNING "Can not found source file: [${field}/${url}_${model}.cc], SKIP!!! fastdeploy executable: [${TEMP_TARGET_NAME}] !")
endif()
unset(TEMP_TARGET_FILE)
unset(TEMP_TARGET_NAME)
@@ -16,9 +16,10 @@ endfunction()
# vision examples
if (WITH_VISION_EXAMPLES)
add_fastdeploy_executable(vision ultralytics yolov5)
add_fastdeploy_executable(vision ppdet ppyoloe)
add_fastdeploy_executable(vision meituan yolov6)
add_fastdeploy_executable(vision wongkinyiu yolov7)
add_fastdeploy_executable(vision megvii yolox)
endif()
# other examples ...
# other examples ...

View File

@@ -0,0 +1,51 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision.h"
int main() {
namespace vis = fastdeploy::vision;
std::string model_file = "ppyoloe_crn_l_300e_coco/model.pdmodel";
std::string params_file = "ppyoloe_crn_l_300e_coco/model.pdiparams";
std::string config_file = "ppyoloe_crn_l_300e_coco/infer_cfg.yml";
std::string img_path = "test.jpeg";
std::string vis_path = "vis.jpeg";
auto model = vis::ppdet::PPYOLOE(model_file, params_file, config_file);
if (!model.Initialized()) {
std::cerr << "Init Failed." << std::endl;
return -1;
}
cv::Mat im = cv::imread(img_path);
cv::Mat vis_im = im.clone();
vis::DetectionResult res;
if (!model.Predict(&im, &res)) {
std::cerr << "Prediction Failed." << std::endl;
return -1;
} else {
std::cout << "Prediction Done!" << std::endl;
}
// 输出预测框结果
std::cout << res.Str() << std::endl;
// 可视化预测结果
vis::Visualize::VisDetection(&vis_im, res);
cv::imwrite(vis_path, vis_im);
std::cout << "Detect Done! Saved: " << vis_path << std::endl;
return 0;
}

View File

@@ -17,7 +17,7 @@ from .fastdeploy_main import Frontend, Backend, FDDataType, TensorInfo, RuntimeO
from .fastdeploy_runtime import *
from . import fastdeploy_main as C
from . import vision
from .download import download
from .download import download, download_and_decompress
def TensorInfoStr(tensor_info):

View File

@@ -156,7 +156,7 @@ def decompress(fname):
def url2dir(url, path, rename=None):
full_name = download(url, path, rename, show_progress=True)
print("SDK is donwloaded, now extracting...")
print("File is donwloaded, now extracting...")
if url.count(".tgz") > 0 or url.count(".tar") > 0 or url.count("zip") > 0:
return decompress(full_name)

View File

@@ -64,6 +64,10 @@ class FASTDEPLOY_DECL FDLogger {
bool verbose_ = true;
};
#ifndef __REL_FILE__
#define __REL_FILE__ __FILE__
#endif
#define FDERROR \
FDLogger(true, "[ERROR]") \
<< __REL_FILE__ << "(" << __LINE__ << ")::" << __FUNCTION__ << "\t"

View File

@@ -16,6 +16,7 @@
#include "fastdeploy/core/config.h"
#ifdef ENABLE_VISION
#include "fastdeploy/vision/ppcls/model.h"
#include "fastdeploy/vision/ppdet/ppyoloe.h"
#include "fastdeploy/vision/ultralytics/yolov5.h"
#include "fastdeploy/vision/wongkinyiu/yolov7.h"
#include "fastdeploy/vision/meituan/yolov6.h"

View File

@@ -15,6 +15,7 @@ from __future__ import absolute_import
from . import evaluation
from . import ppcls
from . import ppdet
from . import ultralytics
from . import meituan
from . import megvii

View File

@@ -1,3 +1,16 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/ppcls/model.h"
#include "fastdeploy/vision/utils/utils.h"
@@ -135,6 +148,6 @@ bool Model::Predict(cv::Mat* im, ClassifyResult* result, int topk) {
return true;
}
} // namespace ppcls
} // namespace vision
} // namespace fastdeploy
} // namespace ppcls
} // namespace vision
} // namespace fastdeploy

View File

@@ -1,7 +1,21 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/fastdeploy_model.h"
#include "fastdeploy/vision/common/result.h"
#include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/result.h"
namespace fastdeploy {
namespace vision {

View File

@@ -14,7 +14,7 @@
#include "fastdeploy/pybind/main.h"
namespace fastdeploy {
void BindPpClsModel(pybind11::module& m) {
void BindPPCls(pybind11::module& m) {
auto ppcls_module = m.def_submodule("ppcls", "Module to deploy PaddleClas.");
pybind11::class_<vision::ppcls::Model, FastDeployModel>(ppcls_module, "Model")
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,

View File

@@ -0,0 +1,39 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import logging
from ... import FastDeployModel, Frontend
from ... import fastdeploy_main as C
class PPYOLOE(FastDeployModel):
def __init__(self,
model_file,
params_file,
config_file,
backend_option=None,
model_format=Frontend.PADDLE):
super(PPYOLOE, self).__init__(backend_option)
assert model_format == Frontend.PADDLE, "PPYOLOE only support model format of Frontend.Paddle now."
self._model = C.vision.ppdet.PPYOLOE(model_file, params_file,
config_file, self._runtime_option,
model_format)
assert self.initialized, "PPYOLOE model initialize failed."
def predict(self, input_image, conf_threshold=0.5, nms_iou_threshold=0.7):
assert input_image is not None, "The input image data is None."
return self._model.predict(input_image, conf_threshold,
nms_iou_threshold)

View File

@@ -0,0 +1,32 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/pybind/main.h"
namespace fastdeploy {
void BindPPDet(pybind11::module& m) {
auto ppdet_module =
m.def_submodule("ppdet", "Module to deploy PaddleDetection.");
pybind11::class_<vision::ppdet::PPYOLOE, FastDeployModel>(ppdet_module,
"PPYOLOE")
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
Frontend>())
.def("predict", [](vision::ppdet::PPYOLOE& self, pybind11::array& data,
float conf_threshold, float nms_iou_threshold) {
auto mat = PyArrayToCvMat(data);
vision::DetectionResult res;
self.Predict(&mat, &res, conf_threshold, nms_iou_threshold);
return res;
});
}
} // namespace fastdeploy

View File

@@ -0,0 +1,170 @@
#include "fastdeploy/vision/ppdet/ppyoloe.h"
#include "fastdeploy/vision/utils/utils.h"
#include "yaml-cpp/yaml.h"
namespace fastdeploy {
namespace vision {
namespace ppdet {
PPYOLOE::PPYOLOE(const std::string& model_file, const std::string& params_file,
const std::string& config_file,
const RuntimeOption& custom_option,
const Frontend& model_format) {
config_file_ = config_file;
valid_cpu_backends = {Backend::ORT, Backend::PDINFER};
valid_gpu_backends = {Backend::ORT, Backend::PDINFER};
runtime_option = custom_option;
runtime_option.model_format = model_format;
runtime_option.model_file = model_file;
runtime_option.params_file = params_file;
initialized = Initialize();
}
bool PPYOLOE::Initialize() {
if (!BuildPreprocessPipelineFromConfig()) {
std::cout << "Failed to build preprocess pipeline from configuration file."
<< std::endl;
return false;
}
if (!InitRuntime()) {
std::cout << "Failed to initialize fastdeploy backend." << std::endl;
return false;
}
return true;
}
bool PPYOLOE::BuildPreprocessPipelineFromConfig() {
processors_.clear();
YAML::Node cfg;
try {
cfg = YAML::LoadFile(config_file_);
} catch (YAML::BadFile& e) {
std::cout << "Failed to load yaml file " << config_file_
<< ", maybe you should check this file." << std::endl;
return false;
}
if (cfg["arch"].as<std::string>() != "YOLO") {
std::cout << "Require the arch of model is YOLO, but arch defined in "
"config file is "
<< cfg["arch"].as<std::string>() << "." << std::endl;
return false;
}
processors_.push_back(std::make_shared<BGR2RGB>());
for (const auto& op : cfg["Preprocess"]) {
std::string op_name = op["type"].as<std::string>();
if (op_name == "NormalizeImage") {
auto mean = op["mean"].as<std::vector<float>>();
auto std = op["std"].as<std::vector<float>>();
bool is_scale = op["is_scale"].as<bool>();
processors_.push_back(std::make_shared<Normalize>(mean, std, is_scale));
} else if (op_name == "Resize") {
bool keep_ratio = op["keep_ratio"].as<bool>();
auto target_size = op["target_size"].as<std::vector<int>>();
int interp = op["interp"].as<int>();
FDASSERT(target_size.size(),
"Require size of target_size be 2, but now it's " +
std::to_string(target_size.size()) + ".");
FDASSERT(!keep_ratio,
"Only support keep_ratio is false while deploy "
"PaddleDetection model.");
int width = target_size[1];
int height = target_size[0];
processors_.push_back(
std::make_shared<Resize>(width, height, -1.0, -1.0, interp, false));
} else if (op_name == "Permute") {
processors_.push_back(std::make_shared<HWC2CHW>());
} else {
std::cout << "Unexcepted preprocess operator: " << op_name << "."
<< std::endl;
return false;
}
}
return true;
}
bool PPYOLOE::Preprocess(Mat* mat, std::vector<FDTensor>* outputs) {
int origin_w = mat->Width();
int origin_h = mat->Height();
for (size_t i = 0; i < processors_.size(); ++i) {
if (!(*(processors_[i].get()))(mat)) {
std::cout << "Failed to process image data in " << processors_[i]->Name()
<< "." << std::endl;
return false;
}
}
outputs->resize(2);
(*outputs)[0].name = InputInfoOfRuntime(0).name;
mat->ShareWithTensor(&((*outputs)[0]));
// reshape to [1, c, h, w]
(*outputs)[0].shape.insert((*outputs)[0].shape.begin(), 1);
(*outputs)[1].Allocate({1, 2}, FDDataType::FP32, InputInfoOfRuntime(1).name);
float* ptr = static_cast<float*>((*outputs)[1].MutableData());
ptr[0] = mat->Height() * 1.0 / mat->Height();
ptr[1] = mat->Width() * 1.0 / mat->Width();
return true;
}
bool PPYOLOE::Postprocess(std::vector<FDTensor>& infer_result,
DetectionResult* result, float conf_threshold,
float nms_threshold) {
FDASSERT(infer_result[1].shape[0] == 1,
"Only support batch = 1 in FastDeploy now.");
int box_num = 0;
if (infer_result[1].dtype == FDDataType::INT32) {
box_num = *(static_cast<int32_t*>(infer_result[1].Data()));
} else if (infer_result[1].dtype == FDDataType::INT64) {
box_num = *(static_cast<int64_t*>(infer_result[1].Data()));
} else {
FDASSERT(
false,
"The output box_num of PPYOLOE model should be type of int32/int64.");
}
result->Reserve(box_num);
float* box_data = static_cast<float*>(infer_result[0].Data());
for (size_t i = 0; i < box_num; ++i) {
if (box_data[i * 6 + 1] < conf_threshold) {
continue;
}
result->label_ids.push_back(box_data[i * 6]);
result->scores.push_back(box_data[i * 6 + 1]);
result->boxes.emplace_back(
std::array<float, 4>{box_data[i * 6 + 2], box_data[i * 6 + 3],
box_data[i * 6 + 4] - box_data[i * 6 + 2],
box_data[i * 6 + 5] - box_data[i * 6 + 3]});
}
return true;
}
bool PPYOLOE::Predict(cv::Mat* im, DetectionResult* result,
float conf_threshold, float iou_threshold) {
Mat mat(*im);
std::vector<FDTensor> processed_data;
if (!Preprocess(&mat, &processed_data)) {
FDERROR << "Failed to preprocess input data while using model:"
<< ModelName() << "." << std::endl;
return false;
}
std::vector<FDTensor> infer_result;
if (!Infer(processed_data, &infer_result)) {
FDERROR << "Failed to inference while using model:" << ModelName() << "."
<< std::endl;
return false;
}
if (!Postprocess(infer_result, result, conf_threshold, iou_threshold)) {
FDERROR << "Failed to postprocess while using model:" << ModelName() << "."
<< std::endl;
return false;
}
return true;
}
} // namespace ppdet
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,44 @@
#pragma once
#include "fastdeploy/fastdeploy_model.h"
#include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/result.h"
#include "fastdeploy/vision/utils/utils.h"
namespace fastdeploy {
namespace vision {
namespace ppdet {
class FASTDEPLOY_DECL PPYOLOE : public FastDeployModel {
public:
PPYOLOE(const std::string& model_file, const std::string& params_file,
const std::string& config_file,
const RuntimeOption& custom_option = RuntimeOption(),
const Frontend& model_format = Frontend::PADDLE);
std::string ModelName() const { return "PaddleDetection/PPYOLOE"; }
virtual bool Initialize();
virtual bool BuildPreprocessPipelineFromConfig();
virtual bool Preprocess(Mat* mat, std::vector<FDTensor>* outputs);
virtual bool Postprocess(std::vector<FDTensor>& infer_result,
DetectionResult* result, float conf_threshold,
float nms_threshold);
virtual bool Predict(cv::Mat* im, DetectionResult* result,
float conf_threshold = 0.5, float nms_threshold = 0.7);
private:
std::vector<std::shared_ptr<Processor>> processors_;
std::string config_file_;
// PaddleDetection can export model without nms
// This flag will help us to handle the different
// situation
bool has_nms_;
};
} // namespace ppdet
} // namespace vision
} // namespace fastdeploy

View File

@@ -16,7 +16,8 @@
namespace fastdeploy {
void BindPpClsModel(pybind11::module& m);
void BindPPCls(pybind11::module& m);
void BindPPDet(pybind11::module& m);
void BindWongkinyiu(pybind11::module& m);
void BindUltralytics(pybind11::module& m);
void BindMeituan(pybind11::module& m);
@@ -41,13 +42,14 @@ void BindVision(pybind11::module& m) {
.def("__repr__", &vision::DetectionResult::Str)
.def("__str__", &vision::DetectionResult::Str);
BindPpClsModel(m);
BindPPCls(m);
BindPPDet(m);
BindUltralytics(m);
BindWongkinyiu(m);
BindMeituan(m);
BindMegvii(m);
#ifdef ENABLE_VISION_VISUALIZE
BindVisualize(m);
#endif
#endif
}
} // namespace fastdeploy
} // namespace fastdeploy

View File

@@ -43,7 +43,7 @@ void Visualize::VisDetection(cv::Mat* im, const DetectionResult& result,
}
std::string text = id + "," + score;
int font = cv::FONT_HERSHEY_SIMPLEX;
cv::Size text_size = cv::getTextSize(text, font, font_size, 0.5, nullptr);
cv::Size text_size = cv::getTextSize(text, font, font_size, 1, nullptr);
cv::Point origin;
origin.x = rect.x;
origin.y = rect.y;
@@ -52,7 +52,7 @@ void Visualize::VisDetection(cv::Mat* im, const DetectionResult& result,
text_size.width, text_size.height);
cv::rectangle(*im, rect, rect_color, line_size);
cv::putText(*im, text, origin, font, font_size, cv::Scalar(255, 255, 255),
0.5);
1);
}
}

View File

@@ -0,0 +1,52 @@
# PaddleDetection/PPYOLOE部署示例
- 当前支持PaddleDetection版本为[release/2.4](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4)
本文档说明如何进行[PPYOLOE](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/configs/ppyoloe)的快速部署推理。本目录结构如下
```
.
├── cpp # C++ 代码目录
│   ├── CMakeLists.txt # C++ 代码编译CMakeLists文件
│   ├── README.md # C++ 代码编译部署文档
│   └── ppyoloe.cc # C++ 示例代码
├── README.md # PPYOLOE 部署文档
└── ppyoloe.py # Python示例代码
```
## 安装FastDeploy
使用如下命令安装FastDeploy注意到此处安装的是`vision-cpu`,也可根据需求安装`vision-gpu`
```
# 安装fastdeploy-python工具
pip install fastdeploy-python
```
## Python部署
执行如下代码即会自动下载PPYOLOE模型和测试图片
```
python ppyoloe.py
```
执行完成后会将可视化结果保存在本地`vis_result.jpg`,同时输出检测结果如下
```
DetectionResult: [xmin, ymin, xmax, ymax, score, label_id]
162.380249,132.057449, 463.178345, 413.167114, 0.962918, 33
414.914642,141.148666, 91.275269, 308.688293, 0.951003, 0
163.449234,129.669067, 35.253891, 135.111786, 0.900734, 0
267.232239,142.290436, 31.578918, 126.329773, 0.848709, 0
581.790833,179.027115, 30.893127, 135.484940, 0.837986, 0
104.407021,72.602615, 22.900627, 75.469055, 0.796468, 0
348.795380,70.122147, 18.806061, 85.829330, 0.785557, 0
364.118683,92.457428, 17.437622, 89.212891, 0.774282, 0
75.180283,192.470490, 41.898407, 55.552414, 0.712569, 56
328.133759,61.894299, 19.100616, 65.633575, 0.710519, 0
504.797760,181.732574, 107.740814, 248.115082, 0.708902, 0
379.063080,64.762360, 15.956146, 68.312546, 0.680725, 0
25.858747,186.564178, 34.958130, 56.007080, 0.580415, 0
```
## 其它文档
- [C++部署](./cpp/README.md)
- [PPYOLOE API文档](./api.md)

View File

@@ -0,0 +1,74 @@
# PPYOLOE API说明
## Python API
### PPYOLOE类
```
fastdeploy.vision.ultralytics.PPYOLOE(model_file, params_file, config_file, runtime_option=None, model_format=fd.Frontend.PADDLE)
```
PPYOLOE模型加载和初始化需同时提供model_file和params_file, 当前仅支持model_format为Paddle格式
**参数**
> * **model_file**(str): 模型文件路径
> * **params_file**(str): 参数文件路径
> * **config_file**(str): 模型推理配置文件
> * **runtime_option**(RuntimeOption): 后端推理配置默认为None即采用默认配置
> * **model_format**(Frontend): 模型格式
#### predict函数
> ```
> PPYOLOE.predict(image_data, conf_threshold=0.25, nms_iou_threshold=0.5)
> ```
> 模型预测结口,输入图像直接输出检测结果。
>
> **参数**
>
> > * **image_data**(np.ndarray): 输入数据注意需为HWCBGR格式
> > * **conf_threshold**(float): 检测框置信度过滤阈值
> > * **nms_iou_threshold**(float): NMS处理过程中iou阈值当模型中包含nms处理时此参数自动无效
示例代码参考[ppyoloe.py](./ppyoloe.py)
## C++ API
### PPYOLOE类
```
fastdeploy::vision::ultralytics::PPYOLOE(
const string& model_file,
const string& params_file,
const string& config_file,
const RuntimeOption& runtime_option = RuntimeOption(),
const Frontend& model_format = Frontend::ONNX)
```
PPYOLOE模型加载和初始化需同时提供model_file和params_file, 当前仅支持model_format为Paddle格式
**参数**
> * **model_file**(str): 模型文件路径
> * **params_file**(str): 参数文件路径
> * **config_file**(str): 模型推理配置文件
> * **runtime_option**(RuntimeOption): 后端推理配置默认为None即采用默认配置
> * **model_format**(Frontend): 模型格式
#### Predict函数
> ```
> YOLOv5::Predict(cv::Mat* im, DetectionResult* result,
> float conf_threshold = 0.25,
> float nms_iou_threshold = 0.5)
> ```
> 模型预测接口,输入图像直接输出检测结果。
>
> **参数**
>
> > * **im**: 输入图像注意需为HWCBGR格式
> > * **result**: 检测结果,包括检测框,各个框的置信度
> > * **conf_threshold**: 检测框置信度过滤阈值
> > * **nms_iou_threshold**: NMS处理过程中iou阈值(当模型中包含nms处理时此参数自动无效
示例代码参考[cpp/yolov5.cc](cpp/yolov5.cc)
## 其它API使用
- [模型部署RuntimeOption配置](../../../docs/api/runtime_option.md)

View File

@@ -0,0 +1,17 @@
PROJECT(ppyoloe_demo C CXX)
CMAKE_MINIMUM_REQUIRED (VERSION 3.16)
# 在低版本ABI环境中通过如下代码进行兼容性编译
# add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)
# 指定下载解压后的fastdeploy库路径
set(FASTDEPLOY_INSTALL_DIR ${PROJECT_SOURCE_DIR}/fastdeploy-linux-x64-0.3.0/)
include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake)
# 添加FastDeploy依赖头文件
include_directories(${FASTDEPLOY_INCS})
add_executable(ppyoloe_demo ${PROJECT_SOURCE_DIR}/ppyoloe.cc)
# 添加FastDeploy库依赖
target_link_libraries(ppyoloe_demo ${FASTDEPLOY_LIBS})

View File

@@ -0,0 +1,39 @@
# 编译PPYOLOE示例
```
# 下载和解压预测库
wget https://bj.bcebos.com/paddle2onnx/fastdeploy/fastdeploy-linux-x64-0.0.3.tgz
tar xvf fastdeploy-linux-x64-0.0.3.tgz
# 编译示例代码
mkdir build & cd build
cmake ..
make -j
# 下载模型和图片
wget https://bj.bcebos.com/paddle2onnx/fastdeploy/models/ppdet/ppyoloe_crn_l_300e_coco.tgz
tar xvf ppyoloe_crn_l_300e_coco.tgz
wget https://raw.githubusercontent.com/PaddlePaddle/PaddleDetection/release/2.4/demo/000000014439_640x640.jpg
# 执行
./ppyoloe_demo
```
执行完后可视化的结果保存在本地`vis_result.jpg`,同时会将检测框输出在终端,如下所示
```
DetectionResult: [xmin, ymin, xmax, ymax, score, label_id]
162.380249,132.057449, 463.178345, 413.167114, 0.962918, 33
414.914642,141.148666, 91.275269, 308.688293, 0.951003, 0
163.449234,129.669067, 35.253891, 135.111786, 0.900734, 0
267.232239,142.290436, 31.578918, 126.329773, 0.848709, 0
581.790833,179.027115, 30.893127, 135.484940, 0.837986, 0
104.407021,72.602615, 22.900627, 75.469055, 0.796468, 0
348.795380,70.122147, 18.806061, 85.829330, 0.785557, 0
364.118683,92.457428, 17.437622, 89.212891, 0.774282, 0
75.180283,192.470490, 41.898407, 55.552414, 0.712569, 56
328.133759,61.894299, 19.100616, 65.633575, 0.710519, 0
504.797760,181.732574, 107.740814, 248.115082, 0.708902, 0
379.063080,64.762360, 15.956146, 68.312546, 0.680725, 0
25.858747,186.564178, 34.958130, 56.007080, 0.580415, 0
```

View File

@@ -0,0 +1,51 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision.h"
int main() {
namespace vis = fastdeploy::vision;
std::string model_file = "ppyoloe_crn_l_300e_coco/model.pdmodel";
std::string params_file = "ppyoloe_crn_l_300e_coco/model.pdiparams";
std::string config_file = "ppyoloe_crn_l_300e_coco/infer_cfg.yml";
std::string img_path = "000000014439_640x640.jpg";
std::string vis_path = "vis.jpeg";
auto model = vis::ppdet::PPYOLOE(model_file, params_file, config_file);
if (!model.Initialized()) {
std::cerr << "Init Failed." << std::endl;
return -1;
}
cv::Mat im = cv::imread(img_path);
cv::Mat vis_im = im.clone();
vis::DetectionResult res;
if (!model.Predict(&im, &res)) {
std::cerr << "Prediction Failed." << std::endl;
return -1;
} else {
std::cout << "Prediction Done!" << std::endl;
}
// 输出预测框结果
std::cout << res.Str() << std::endl;
// 可视化预测结果
vis::Visualize::VisDetection(&vis_im, res);
cv::imwrite(vis_path, vis_im);
std::cout << "Detect Done! Saved: " << vis_path << std::endl;
return 0;
}

View File

@@ -0,0 +1,24 @@
import fastdeploy as fd
import cv2
# 下载模型和测试图片
model_url = "https://bj.bcebos.com/paddle2onnx/fastdeploy/models/ppdet/ppyoloe_crn_l_300e_coco.tgz"
test_jpg_url = "https://raw.githubusercontent.com/PaddlePaddle/PaddleDetection/release/2.4/demo/000000014439_640x640.jpg"
fd.download_and_decompress(model_url, ".")
fd.download(test_jpg_url, ".", show_progress=True)
# 加载模型
model = fd.vision.ppdet.PPYOLOE("ppyoloe_crn_l_300e_coco/model.pdmodel",
"ppyoloe_crn_l_300e_coco/model.pdiparams",
"ppyoloe_crn_l_300e_coco/infer_cfg.yml")
# 预测图片
im = cv2.imread("000000014439_640x640.jpg")
result = model.predict(im, conf_threshold=0.5)
# 可视化结果
fd.vision.visualize.vis_detection(im, result)
cv2.imwrite("vis_result.jpg", im)
# 输出预测结果
print(result)