From 30bb233db8e346836353ea146863bd7d231fcb8d Mon Sep 17 00:00:00 2001 From: DefTruth <31974251+DefTruth@users.noreply.github.com> Date: Tue, 30 Aug 2022 15:01:20 +0800 Subject: [PATCH] [feature][vision] Add YOLOv7 End2End model with ORT NMS (#152) * [feature][cmake] enable build fastdeploy with examples * [feature][cmake] enable build fastdeploy with examples * [feature][vision] Add YOLOv7 End2End model with ORT NMS * [docs] update yolov7end2end_ort docs update yolov7end2end_ort docs * [docs] update yolov7end2end_ort examples docs update yolov7end2end_ort examples docs * [docs] update yolov7end2end_ort examples docs Co-authored-by: Jason --- csrc/fastdeploy/vision.h | 1 + .../detection/contrib/yolov7end2end_ort.cc | 249 ++++++++++++++++++ .../detection/contrib/yolov7end2end_ort.h | 92 +++++++ .../contrib/yolov7end2end_ort_pybind.cc | 41 +++ .../vision/detection/detection_pybind.cc | 2 + .../detection/yolov7end2end_ort/README.md | 41 +++ .../yolov7end2end_ort/cpp/CMakeLists.txt | 14 + .../detection/yolov7end2end_ort/cpp/README.md | 93 +++++++ .../detection/yolov7end2end_ort/cpp/infer.cc | 110 ++++++++ .../yolov7end2end_ort/python/README.md | 84 ++++++ .../yolov7end2end_ort/python/infer.py | 53 ++++ fastdeploy/vision/detection/__init__.py | 1 + .../detection/contrib/yolov7end2end_ort.py | 105 ++++++++ 13 files changed, 886 insertions(+) create mode 100644 csrc/fastdeploy/vision/detection/contrib/yolov7end2end_ort.cc create mode 100644 csrc/fastdeploy/vision/detection/contrib/yolov7end2end_ort.h create mode 100644 csrc/fastdeploy/vision/detection/contrib/yolov7end2end_ort_pybind.cc create mode 100644 examples/vision/detection/yolov7end2end_ort/README.md create mode 100644 examples/vision/detection/yolov7end2end_ort/cpp/CMakeLists.txt create mode 100644 examples/vision/detection/yolov7end2end_ort/cpp/README.md create mode 100644 examples/vision/detection/yolov7end2end_ort/cpp/infer.cc create mode 100644 examples/vision/detection/yolov7end2end_ort/python/README.md create mode 100644 examples/vision/detection/yolov7end2end_ort/python/infer.py create mode 100644 fastdeploy/vision/detection/contrib/yolov7end2end_ort.py diff --git a/csrc/fastdeploy/vision.h b/csrc/fastdeploy/vision.h index 227df7840..59c2ac522 100644 --- a/csrc/fastdeploy/vision.h +++ b/csrc/fastdeploy/vision.h @@ -23,6 +23,7 @@ #include "fastdeploy/vision/detection/contrib/yolov5lite.h" #include "fastdeploy/vision/detection/contrib/yolov6.h" #include "fastdeploy/vision/detection/contrib/yolov7.h" +#include "fastdeploy/vision/detection/contrib/yolov7end2end_ort.h" #include "fastdeploy/vision/detection/contrib/yolox.h" #include "fastdeploy/vision/detection/ppdet/model.h" #include "fastdeploy/vision/facedet/contrib/retinaface.h" diff --git a/csrc/fastdeploy/vision/detection/contrib/yolov7end2end_ort.cc b/csrc/fastdeploy/vision/detection/contrib/yolov7end2end_ort.cc new file mode 100644 index 000000000..fea86df23 --- /dev/null +++ b/csrc/fastdeploy/vision/detection/contrib/yolov7end2end_ort.cc @@ -0,0 +1,249 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fastdeploy/vision/detection/contrib/yolov7end2end_ort.h"
+#include "fastdeploy/utils/perf.h"
+#include "fastdeploy/vision/utils/utils.h"
+
+namespace fastdeploy {
+namespace vision {
+namespace detection {
+
+void YOLOv7End2EndORT::LetterBox(Mat* mat, const std::vector<int>& size,
+                                 const std::vector<float>& color, bool _auto,
+                                 bool scale_fill, bool scale_up, int stride) {
+  float scale =
+      std::min(size[1] * 1.0 / mat->Height(), size[0] * 1.0 / mat->Width());
+  if (!scale_up) {
+    scale = std::min(scale, 1.0f);
+  }
+
+  int resize_h = int(round(mat->Height() * scale));
+  int resize_w = int(round(mat->Width() * scale));
+
+  int pad_w = size[0] - resize_w;
+  int pad_h = size[1] - resize_h;
+  if (_auto) {
+    pad_h = pad_h % stride;
+    pad_w = pad_w % stride;
+  } else if (scale_fill) {
+    pad_h = 0;
+    pad_w = 0;
+    resize_h = size[1];
+    resize_w = size[0];
+  }
+  if (resize_h != mat->Height() || resize_w != mat->Width()) {
+    Resize::Run(mat, resize_w, resize_h);
+  }
+  if (pad_h > 0 || pad_w > 0) {
+    float half_h = pad_h * 1.0 / 2;
+    int top = int(round(half_h - 0.1));
+    int bottom = int(round(half_h + 0.1));
+    float half_w = pad_w * 1.0 / 2;
+    int left = int(round(half_w - 0.1));
+    int right = int(round(half_w + 0.1));
+    Pad::Run(mat, top, bottom, left, right, color);
+  }
+}
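+
+// Worked example of the letterbox arithmetic above (illustrative values):
+// for a 1280x720 input with size = {640, 640}, _auto = false and
+// scale_fill = false, scale = min(640/720, 640/1280) = 0.5, so the image is
+// resized to 640x360 and then padded by pad_h = 280 rows of `color`
+// (top = 140, bottom = 140) to reach the static 640x640 input shape.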
+
+YOLOv7End2EndORT::YOLOv7End2EndORT(const std::string& model_file,
+                                   const std::string& params_file,
+                                   const RuntimeOption& custom_option,
+                                   const Frontend& model_format) {
+  if (model_format == Frontend::ONNX) {
+    valid_cpu_backends = {Backend::ORT};
+    valid_gpu_backends = {Backend::ORT};  // NO TRT
+  } else {
+    valid_cpu_backends = {Backend::PDINFER};
+    valid_gpu_backends = {Backend::PDINFER};
+  }
+  runtime_option = custom_option;
+  runtime_option.model_format = model_format;
+  runtime_option.model_file = model_file;
+  if (custom_option.backend == Backend::TRT) {
+    FDWARNING << "Backend::TRT is not supported for YOLOv7End2EndORT, "
+              << "will fall back to Backend::ORT." << std::endl;
+  }
+  initialized = Initialize();
+}
+
+bool YOLOv7End2EndORT::Initialize() {
+  // parameters for preprocess
+  size = {640, 640};
+  padding_value = {114.0, 114.0, 114.0};
+  is_mini_pad = false;
+  is_no_pad = false;
+  is_scale_up = false;
+  stride = 32;
+
+  if (!InitRuntime()) {
+    FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
+    return false;
+  }
+  // Check whether the input shape is dynamic after the runtime is
+  // initialized. Note that we need to force is_mini_pad to 'false' to keep a
+  // static shape after padding (LetterBox) when is_dynamic_input_ is 'false'.
+  is_dynamic_input_ = false;
+  auto shape = InputInfoOfRuntime(0).shape;
+  for (size_t i = 0; i < shape.size(); ++i) {
+    // if height or width is dynamic
+    if (i >= 2 && shape[i] <= 0) {
+      is_dynamic_input_ = true;
+      break;
+    }
+  }
+  if (!is_dynamic_input_) {
+    is_mini_pad = false;
+  }
+  return true;
+}
+
+bool YOLOv7End2EndORT::Preprocess(
+    Mat* mat, FDTensor* output,
+    std::map<std::string, std::array<float, 2>>* im_info) {
+  float ratio = std::min(size[1] * 1.0f / static_cast<float>(mat->Height()),
+                         size[0] * 1.0f / static_cast<float>(mat->Width()));
+  if (ratio != 1.0) {
+    int interp = cv::INTER_AREA;
+    if (ratio > 1.0) {
+      interp = cv::INTER_LINEAR;
+    }
+    int resize_h = int(mat->Height() * ratio);
+    int resize_w = int(mat->Width() * ratio);
+    Resize::Run(mat, resize_w, resize_h, -1, -1, interp);
+  }
+  YOLOv7End2EndORT::LetterBox(mat, size, padding_value, is_mini_pad, is_no_pad,
+                              is_scale_up, stride);
+  BGR2RGB::Run(mat);
+  std::vector<float> alpha = {1.0f / 255.0f, 1.0f / 255.0f, 1.0f / 255.0f};
+  std::vector<float> beta = {0.0f, 0.0f, 0.0f};
+  Convert::Run(mat, alpha, beta);
+
+  (*im_info)["output_shape"] = {static_cast<float>(mat->Height()),
+                                static_cast<float>(mat->Width())};
+
+  HWC2CHW::Run(mat);
+  Cast::Run(mat, "float");
+  mat->ShareWithTensor(output);
+  output->shape.insert(output->shape.begin(), 1);  // reshape to n, c, h, w
+  return true;
+}
+
+bool YOLOv7End2EndORT::Postprocess(
+    FDTensor& infer_result, DetectionResult* result,
+    const std::map<std::string, std::array<float, 2>>& im_info,
+    float conf_threshold) {
+  if (infer_result.dtype != FDDataType::FP32) {
+    FDERROR << "Only support post process with float32 data." << std::endl;
+    return false;
+  }
+  // detection succeeded, but no valid objects were found.
+  if (infer_result.shape[0] == 0) {
+    return true;
+  }
+
+  result->Clear();
+  result->Reserve(infer_result.shape[0]);
+  // (?, 7): (batch_id, x0, y0, x1, y1, cls_id, score) after NMS
+  float* data = static_cast<float*>(infer_result.Data());
+  for (size_t i = 0; i < infer_result.shape[0]; ++i) {
+    const float* box_cls_ptr = data + (i * 7);
+    int64_t batch_id = static_cast<int64_t>(box_cls_ptr[0] + 0.5f);  // 0, 1, ...
+    FDASSERT(batch_id == 0,
+             "Only support batch=1 now, but found batch_id != 0.");
+    float confidence = box_cls_ptr[6];
+    if (confidence <= conf_threshold) {
+      continue;
+    }
+    int32_t label_id = static_cast<int32_t>(box_cls_ptr[5] + 0.5f);
+    float x1 = box_cls_ptr[1];
+    float y1 = box_cls_ptr[2];
+    float x2 = box_cls_ptr[3];
+    float y2 = box_cls_ptr[4];
+
+    result->boxes.emplace_back(std::array<float, 4>{x1, y1, x2, y2});
+    result->label_ids.push_back(label_id);
+    result->scores.push_back(confidence);
+  }
+
+  if (result->boxes.size() == 0) {
+    return true;
+  }
+
+  // scale the boxes back to the original image shape
+  auto iter_out = im_info.find("output_shape");
+  auto iter_ipt = im_info.find("input_shape");
+  FDASSERT(iter_out != im_info.end() && iter_ipt != im_info.end(),
+           "Cannot find input_shape or output_shape from im_info.");
+  float out_h = iter_out->second[0];
+  float out_w = iter_out->second[1];
+  float ipt_h = iter_ipt->second[0];
+  float ipt_w = iter_ipt->second[1];
+  float scale = std::min(out_h / ipt_h, out_w / ipt_w);
+  float pad_h = (out_h - ipt_h * scale) / 2.0f;
+  float pad_w = (out_w - ipt_w * scale) / 2.0f;
+  if (is_mini_pad) {
+    pad_h = static_cast<float>(static_cast<int>(pad_h) % stride);
+    pad_w = static_cast<float>(static_cast<int>(pad_w) % stride);
+  }
+  for (size_t i = 0; i < result->boxes.size(); ++i) {
+    result->boxes[i][0] = std::max((result->boxes[i][0] - pad_w) / scale, 0.0f);
+    result->boxes[i][1] = std::max((result->boxes[i][1] - pad_h) / scale, 0.0f);
+    result->boxes[i][2] = std::max((result->boxes[i][2] - pad_w) / scale, 0.0f);
+    result->boxes[i][3] = std::max((result->boxes[i][3] - pad_h) / scale, 0.0f);
+    result->boxes[i][0] = std::min(result->boxes[i][0], ipt_w - 1.0f);
+    result->boxes[i][1] = std::min(result->boxes[i][1], ipt_h - 1.0f);
+    result->boxes[i][2] = std::min(result->boxes[i][2], ipt_w - 1.0f);
+    result->boxes[i][3] = std::min(result->boxes[i][3], ipt_h - 1.0f);
+  }
+  return true;
+}
+
+bool YOLOv7End2EndORT::Predict(cv::Mat* im, DetectionResult* result,
+                               float conf_threshold) {
+  Mat mat(*im);
+  std::vector<FDTensor> input_tensors(1);
+
+  std::map<std::string, std::array<float, 2>> im_info;
+
+  // Record the shape of the image and the shape of the preprocessed image
+  im_info["input_shape"] = {static_cast<float>(mat.Height()),
+                            static_cast<float>(mat.Width())};
+  im_info["output_shape"] = {static_cast<float>(mat.Height()),
+                             static_cast<float>(mat.Width())};
+
+  if (!Preprocess(&mat, &input_tensors[0], &im_info)) {
+    FDERROR << "Failed to preprocess input image." << std::endl;
+    return false;
+  }
+
+  input_tensors[0].name = InputInfoOfRuntime(0).name;
+  std::vector<FDTensor> output_tensors;
+  if (!Infer(input_tensors, &output_tensors)) {
+    FDERROR << "Failed to run inference." << std::endl;
+    return false;
+  }
+
+  if (!Postprocess(output_tensors[0], result, im_info, conf_threshold)) {
+    FDERROR << "Failed to postprocess." << std::endl;
+    return false;
+  }
+
+  return true;
+}
+
+}  // namespace detection
+}  // namespace vision
+}  // namespace fastdeploy
diff --git a/csrc/fastdeploy/vision/detection/contrib/yolov7end2end_ort.h b/csrc/fastdeploy/vision/detection/contrib/yolov7end2end_ort.h
new file mode 100644
index 000000000..ac2ba6aa6
--- /dev/null
+++ b/csrc/fastdeploy/vision/detection/contrib/yolov7end2end_ort.h
@@ -0,0 +1,92 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "fastdeploy/fastdeploy_model.h"
+#include "fastdeploy/vision/common/processors/transform.h"
+#include "fastdeploy/vision/common/result.h"
+
+namespace fastdeploy {
+namespace vision {
+namespace detection {
+
+class FASTDEPLOY_DECL YOLOv7End2EndORT : public FastDeployModel {
+ public:
+  YOLOv7End2EndORT(const std::string& model_file,
+                   const std::string& params_file = "",
+                   const RuntimeOption& custom_option = RuntimeOption(),
+                   const Frontend& model_format = Frontend::ONNX);
+
+  // Name of the model
+  virtual std::string ModelName() const { return "yolov7end2end_ort"; }
+
+  // Prediction interface, i.e. the entry point called by users
+  // im: input data; for CV models this is a cv::Mat
+  // result: output structure holding the model prediction
+  // conf_threshold: parameter used in postprocessing
+  virtual bool Predict(cv::Mat* im, DetectionResult* result,
+                       float conf_threshold = 0.25);
+
+  // The following members are parameters used during prediction, mostly for
+  // pre/postprocessing. After creating the model, users may modify them
+  // according to the model's requirements and their own needs.
+  // tuple of (width, height)
+  std::vector<int> size;
+  // padding value; its size should match the number of channels
+  std::vector<float> padding_value;
+  // only pad to the minimum rectangle whose height and width are multiples of
+  // stride
+  bool is_mini_pad;
+  // when is_mini_pad = false and is_no_pad = true, resize the image to the
+  // target size without padding
+  bool is_no_pad;
+  // if is_scale_up is false, the input image can only be zoomed out; the
+  // maximum resize scale cannot exceed 1.0
+  bool is_scale_up;
+  // padding stride, used together with is_mini_pad
+  int stride;
+
+ private:
+  // Initialization, including backend setup and other operations required
+  // for inference
+  bool Initialize();
+
+  // Preprocessing of the input image
+  // Mat is the data structure defined by FastDeploy
+  // FDTensor holds the preprocessed tensor passed to the backend
+  // im_info stores data recorded during preprocessing, needed by postprocess
+  bool Preprocess(Mat* mat, FDTensor* output,
+                  std::map<std::string, std::array<float, 2>>* im_info);
+
+  // Postprocessing of the backend inference result, returned to the user
+  // infer_result: output tensor from the backend
+  // result: final model prediction
+  // im_info: information recorded during preprocessing, used to restore boxes
+  // conf_threshold: confidence threshold used to filter boxes
+  bool Postprocess(FDTensor& infer_result, DetectionResult* result,
+                   const std::map<std::string, std::array<float, 2>>& im_info,
+                   float conf_threshold);
+
+  // Apply LetterBox processing to the image
+  // mat: the original input image
+  // size: the input image size expected by the model
+  void LetterBox(Mat* mat, const std::vector<int>& size,
+                 const std::vector<float>& color, bool _auto,
+                 bool scale_fill = false, bool scale_up = true,
+                 int stride = 32);
+
+  bool is_dynamic_input_;
+};
+}  // namespace detection
+}  // namespace vision
+}  // namespace fastdeploy
diff --git a/csrc/fastdeploy/vision/detection/contrib/yolov7end2end_ort_pybind.cc b/csrc/fastdeploy/vision/detection/contrib/yolov7end2end_ort_pybind.cc
new file mode 100644
index 000000000..79794aaf2
--- /dev/null
+++ b/csrc/fastdeploy/vision/detection/contrib/yolov7end2end_ort_pybind.cc
@@ -0,0 +1,41 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fastdeploy/pybind/main.h"
+
+namespace fastdeploy {
+void BindYOLOv7End2EndORT(pybind11::module& m) {
+  pybind11::class_<vision::detection::YOLOv7End2EndORT, FastDeployModel>(
+      m, "YOLOv7End2EndORT")
+      .def(pybind11::init<std::string, std::string, RuntimeOption, Frontend>())
+      .def("predict",
+           [](vision::detection::YOLOv7End2EndORT& self, pybind11::array& data,
+              float conf_threshold) {
+             auto mat = PyArrayToCvMat(data);
+             vision::DetectionResult res;
+             self.Predict(&mat, &res, conf_threshold);
+             return res;
+           })
+      .def_readwrite("size", &vision::detection::YOLOv7End2EndORT::size)
+      .def_readwrite("padding_value",
+                     &vision::detection::YOLOv7End2EndORT::padding_value)
+      .def_readwrite("is_mini_pad",
+                     &vision::detection::YOLOv7End2EndORT::is_mini_pad)
+      .def_readwrite("is_no_pad",
+                     &vision::detection::YOLOv7End2EndORT::is_no_pad)
+      .def_readwrite("is_scale_up",
+                     &vision::detection::YOLOv7End2EndORT::is_scale_up)
+      .def_readwrite("stride", &vision::detection::YOLOv7End2EndORT::stride);
+}
+}  // namespace fastdeploy
diff --git a/csrc/fastdeploy/vision/detection/detection_pybind.cc b/csrc/fastdeploy/vision/detection/detection_pybind.cc
index a865dc11e..931f298a9 100644
--- a/csrc/fastdeploy/vision/detection/detection_pybind.cc
+++ b/csrc/fastdeploy/vision/detection/detection_pybind.cc
@@ -25,6 +25,7 @@ void BindYOLOv5(pybind11::module& m);
 void BindYOLOX(pybind11::module& m);
 void BindNanoDetPlus(pybind11::module& m);
 void BindPPDet(pybind11::module& m);
+void BindYOLOv7End2EndORT(pybind11::module& m);
 
 void BindDetection(pybind11::module& m) {
   auto detection_module =
@@ -38,5 +39,6 @@ void BindDetection(pybind11::module& m) {
   BindYOLOv5(detection_module);
   BindYOLOX(detection_module);
   BindNanoDetPlus(detection_module);
+  BindYOLOv7End2EndORT(detection_module);
 }
 }  // namespace fastdeploy
diff --git a/examples/vision/detection/yolov7end2end_ort/README.md b/examples/vision/detection/yolov7end2end_ort/README.md
new file mode 100644
index 000000000..fdfd08019
--- /dev/null
+++ b/examples/vision/detection/yolov7end2end_ort/README.md
@@ -0,0 +1,41 @@
+# YOLOv7End2EndORT Deployment Preparation
+
+The YOLOv7End2EndORT deployment is based on the code of the [YOLOv7](https://github.com/WongKinYiu/yolov7/tree/v0.1) branch and its [COCO pretrained models](https://github.com/WongKinYiu/yolov7/releases/tag/v0.1). Note that YOLOv7End2EndORT is dedicated to End2End models exported from YOLOv7 with the [ORT_NMS](https://github.com/WongKinYiu/yolov7/blob/main/models/experimental.py#L87) operator. For models exported without NMS, use the YOLOv7 class; for End2End models exported with [TRT_NMS](https://github.com/WongKinYiu/yolov7/blob/main/models/experimental.py#L111), use YOLOv7End2EndTRT.
+
+  - (1) The *.pt files provided by the [official repository](https://github.com/WongKinYiu/yolov7/releases/tag/v0.1) can be deployed after the [Export ONNX Model](#export-onnx-model) step; *.trt and *.pose models are not supported;
+  - (2) YOLOv7 models trained on your own data can be deployed after the [Export ONNX Model](#export-onnx-model) step; refer to the [Detailed Deployment Documents](#detailed-deployment-documents) to finish the deployment.
+
+
+## Export ONNX Model
+
+```bash
+# Download the yolov7 model file
+wget https://github.com/WongKinYiu/yolov7/releases/download/v0.1/yolov7.pt
+
+# Export an ONNX file with ORT_NMS (tip: matches the YOLOv7 release v0.1 code)
+python export.py --weights yolov7.pt --grid --end2end --simplify --topk-all 100 --iou-thres 0.65 --conf-thres 0.35 --img-size 640 640 --max-wh 640
+# The export commands for other models are similar; replace yolov7.pt with yolov7x.pt, yolov7-d6.pt, yolov7-w6.pt ... (see the example below)
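+# For example (assuming the same v0.1 export script and the corresponding downloaded weights):
+# python export.py --weights yolov7x.pt --grid --end2end --simplify --topk-all 100 --iou-thres 0.65 --conf-thres 0.35 --img-size 640 640 --max-wh 640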
+``` + +## 下载预训练ONNX模型 + +为了方便开发者的测试,下面提供了YOLOv7End2EndORT导出的各系列模型,开发者可直接下载使用。(下表中模型的精度来源于源官方库) +| 模型 | 大小 | 精度 | +|:---------------------------------------------------------------- |:----- |:----- | +| [yolov7-end2end-ort-nms](https://bj.bcebos.com/paddlehub/fastdeploy/yolov7-end2end-ort-nms.onnx) | 141MB | 51.4% | +| [yolov7x-end2end-ort-nms](https://bj.bcebos.com/paddlehub/fastdeploy/yolov7x-end2end-ort-nms.onnx) | 273MB | 53.1% | +| [yolov7-w6-end2end-ort-nms](https://bj.bcebos.com/paddlehub/fastdeploy/yolov7-w6-end2end-ort-nms.onnx) | 269MB | 54.9% | +| [yolov7-e6-end2end-ort-nms](https://bj.bcebos.com/paddlehub/fastdeploy/yolov7-e6-end2end-ort-nms.onnx) | 372MB | 56.0% | +| [yolov7-d6-end2end-ort-nms](https://bj.bcebos.com/paddlehub/fastdeploy/yolov7-d6-end2end-ort-nms.onnx) | 511MB | 56.6% | +| [yolov7-e6e-end2end-ort-nms](https://bj.bcebos.com/paddlehub/fastdeploy/yolov7-e6e-end2end-ort-nms.onnx) | 579MB | 56.8% | + + +## 详细部署文档 + +- [Python部署](python) +- [C++部署](cpp) + + +## 版本说明 + +- 本版本文档和代码基于[YOLOv7 0.1](https://github.com/WongKinYiu/yolov7/tree/v0.1) 编写 diff --git a/examples/vision/detection/yolov7end2end_ort/cpp/CMakeLists.txt b/examples/vision/detection/yolov7end2end_ort/cpp/CMakeLists.txt new file mode 100644 index 000000000..fea1a2888 --- /dev/null +++ b/examples/vision/detection/yolov7end2end_ort/cpp/CMakeLists.txt @@ -0,0 +1,14 @@ +PROJECT(infer_demo C CXX) +CMAKE_MINIMUM_REQUIRED (VERSION 3.12) + +# 指定下载解压后的fastdeploy库路径 +option(FASTDEPLOY_INSTALL_DIR "Path of downloaded fastdeploy sdk.") + +include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake) + +# 添加FastDeploy依赖头文件 +include_directories(${FASTDEPLOY_INCS}) + +add_executable(infer_demo ${PROJECT_SOURCE_DIR}/infer.cc) +# 添加FastDeploy库依赖 +target_link_libraries(infer_demo ${FASTDEPLOY_LIBS}) diff --git a/examples/vision/detection/yolov7end2end_ort/cpp/README.md b/examples/vision/detection/yolov7end2end_ort/cpp/README.md new file mode 100644 index 000000000..450269170 --- /dev/null +++ b/examples/vision/detection/yolov7end2end_ort/cpp/README.md @@ -0,0 +1,93 @@ +# YOLOv7End2EndORT C++部署示例 + +本目录下提供`infer.cc`快速完成YOLOv7End2EndORT在CPU/GPU部署的示例。 + +在部署前,需确认以下两个步骤 + +- 1. 软硬件环境满足要求,参考[FastDeploy环境要求](../../../../../docs/the%20software%20and%20hardware%20requirements.md) +- 2. 根据开发环境,下载预编译部署库和samples代码,参考[FastDeploy预编译库](../../../../../docs/quick_start) + +以Linux上推理为例,在本目录执行如下命令即可完成编译测试 + +```bash +mkdir build +cd build +wget https://bj.bcebos.com/fastdeploy/release/cpp/fastdeploy-linux-x64-gpu-0.2.0.tgz +tar xvf fastdeploy-linux-x64-0.2.0.tgz +cmake .. -DFASTDEPLOY_INSTALL_DIR=${PWD}/fastdeploy-linux-x64-gpu-0.2.0 +make -j +# 如果预编译库还没有支持该模型,请从develop分支源码编译最新的SDK + +#下载官方转换好的yolov7模型文件和测试图片 +wget https://bj.bcebos.com/paddlehub/fastdeploy/yolov7-end2end-ort-nms.onnx +wget https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg + + +# CPU推理 +./infer_demo yolov7-end2end-ort-nms.onnx 000000014439.jpg 0 +# GPU推理 +./infer_demo yolov7-end2end-ort-nms.onnx 000000014439.jpg 1 +# TensorRT + GPU 部署 (暂不支持 会回退到 ORT + GPU) +./infer_demo yolov7-end2end-ort-nms.onnx 000000014439.jpg 2 +``` + +运行完成可视化结果如下图所示 + +
+*(visualized detection result image)*
+
+The above commands only work on Linux or macOS. For instructions on using the SDK on Windows, refer to:
+- [How to use the FastDeploy C++ SDK on Windows](../../../../../docs/compile/how_to_use_sdk_on_windows.md)
+
+Note that YOLOv7End2EndORT is dedicated to End2End models exported from YOLOv7 with the [ORT_NMS](https://github.com/WongKinYiu/yolov7/blob/main/models/experimental.py#L87) operator. For models exported without NMS, use the YOLOv7 class; for End2End models exported with [TRT_NMS](https://github.com/WongKinYiu/yolov7/blob/main/models/experimental.py#L111), use YOLOv7End2EndTRT.
+
+## YOLOv7End2EndORT C++ Interface
+
+### YOLOv7End2EndORT Class
+
+```c++
+fastdeploy::vision::detection::YOLOv7End2EndORT(
+        const string& model_file,
+        const string& params_file = "",
+        const RuntimeOption& runtime_option = RuntimeOption(),
+        const Frontend& model_format = Frontend::ONNX)
+```
+
+Loads and initializes the YOLOv7End2EndORT model, where model_file is the exported ONNX model.
+
+**Parameters**
+
+> * **model_file**(str): path of the model file
+> * **params_file**(str): path of the parameter file; pass an empty string when the model format is ONNX
+> * **runtime_option**(RuntimeOption): backend inference configuration; None uses the default configuration
+> * **model_format**(Frontend): model format, ONNX by default
+
+#### Predict Function
+
+> ```c++
+> YOLOv7End2EndORT::Predict(cv::Mat* im, DetectionResult* result,
+>                           float conf_threshold = 0.25)
+> ```
+>
+> The prediction interface: takes an input image and returns the detection result directly.
+>
+> **Parameters**
+>
+> > * **im**: input image; must be in HWC, BGR format
+> > * **result**: detection result, including detection boxes and the confidence of each box; see [Vision Model Prediction Results](../../../../../docs/api/vision_results/) for details of DetectionResult
+> > * **conf_threshold**: confidence threshold for filtering boxes; since a score threshold was already specified when the YOLOv7 End2End model was exported to ONNX, this parameter only takes effect when it is greater than the threshold specified at export time.
+
+### Class Member Variables
+#### Preprocessing Parameters
+Users can modify the following preprocessing parameters according to their needs, which affects the final inference and deployment results
+
+> > * **size**(vector<int>): target size of the resize during preprocessing, containing two integers meaning [width, height]; default [640, 640]
+> > * **padding_value**(vector<float>): padding value applied during resize, containing three floats for the three channels; default [114, 114, 114]
+> > * **is_no_pad**(bool): whether to resize the image without padding; `is_no_pad=true` means padding is not used; default `is_no_pad=false`
+> > * **is_mini_pad**(bool): whether to pad the resized image only up to the size closest to `size` such that the number of padded pixels is divisible by the `stride` member; default `is_mini_pad=false`
+> > * **stride**(int): used together with the `is_mini_pad` member; default `stride=32`
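+
+For example, a minimal sketch of overriding these members before calling `Predict` (illustrative values; whether a model accepts a different input size depends on how it was exported):
+
+```c++
+fastdeploy::vision::detection::YOLOv7End2EndORT model("yolov7-end2end-ort-nms.onnx");
+model.size = {640, 640};                         // [width, height] of the letterboxed input
+model.padding_value = {114.0f, 114.0f, 114.0f};  // per-channel letterbox fill value
+model.is_no_pad = false;                         // keep letterbox padding enabled
+```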
+
+- [Model Introduction](../../)
+- [Python Deployment](../python)
+- [Vision Model Prediction Results](../../../../../docs/api/vision_results/)
diff --git a/examples/vision/detection/yolov7end2end_ort/cpp/infer.cc b/examples/vision/detection/yolov7end2end_ort/cpp/infer.cc
new file mode 100644
index 000000000..a0e70544a
--- /dev/null
+++ b/examples/vision/detection/yolov7end2end_ort/cpp/infer.cc
@@ -0,0 +1,110 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fastdeploy/vision.h"
+
+void CpuInfer(const std::string& model_file, const std::string& image_file) {
+  auto model = fastdeploy::vision::detection::YOLOv7End2EndORT(model_file);
+  if (!model.Initialized()) {
+    std::cerr << "Failed to initialize." << std::endl;
+    return;
+  }
+
+  auto im = cv::imread(image_file);
+  auto im_bak = im.clone();
+
+  fastdeploy::vision::DetectionResult res;
+  if (!model.Predict(&im, &res)) {
+    std::cerr << "Failed to predict." << std::endl;
+    return;
+  }
+  std::cout << res.Str() << std::endl;
+
+  auto vis_im = fastdeploy::vision::Visualize::VisDetection(im_bak, res);
+  cv::imwrite("vis_result.jpg", vis_im);
+  std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
+}
+
+void GpuInfer(const std::string& model_file, const std::string& image_file) {
+  auto option = fastdeploy::RuntimeOption();
+  option.UseGpu();
+  auto model =
+      fastdeploy::vision::detection::YOLOv7End2EndORT(model_file, "", option);
+  if (!model.Initialized()) {
+    std::cerr << "Failed to initialize." << std::endl;
+    return;
+  }
+
+  auto im = cv::imread(image_file);
+  auto im_bak = im.clone();
+
+  fastdeploy::vision::DetectionResult res;
+  if (!model.Predict(&im, &res)) {
+    std::cerr << "Failed to predict." << std::endl;
+    return;
+  }
+  std::cout << res.Str() << std::endl;
+
+  auto vis_im = fastdeploy::vision::Visualize::VisDetection(im_bak, res);
+  cv::imwrite("vis_result.jpg", vis_im);
+  std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
+}
+
+void TrtInfer(const std::string& model_file, const std::string& image_file) {
+  auto option = fastdeploy::RuntimeOption();
+  option.UseGpu();
+  option.UseTrtBackend();
+  option.SetTrtInputShape("images", {1, 3, 640, 640});
+  auto model =
+      fastdeploy::vision::detection::YOLOv7End2EndORT(model_file, "", option);
+  if (!model.Initialized()) {
+    std::cerr << "Failed to initialize." << std::endl;
+    return;
+  }
+
+  auto im = cv::imread(image_file);
+  auto im_bak = im.clone();
+
+  fastdeploy::vision::DetectionResult res;
+  if (!model.Predict(&im, &res)) {
+    std::cerr << "Failed to predict." << std::endl;
+    return;
+  }
+  std::cout << res.Str() << std::endl;
+
+  auto vis_im = fastdeploy::vision::Visualize::VisDetection(im_bak, res);
+  cv::imwrite("vis_result.jpg", vis_im);
+  std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
+}
+
+int main(int argc, char* argv[]) {
+  if (argc < 4) {
+    std::cout << "Usage: infer_demo path/to/model path/to/image run_option, "
+                 "e.g. ./infer_demo ./yolov7-end2end-ort.onnx ./test.jpeg 0"
+              << std::endl;
+    std::cout << "The data type of run_option is int, 0: run with cpu; 1: run "
                 "with gpu; 2: run with gpu and use tensorrt backend."
+              << std::endl;
+    return -1;
+  }
+
+  if (std::atoi(argv[3]) == 0) {
+    CpuInfer(argv[1], argv[2]);
+  } else if (std::atoi(argv[3]) == 1) {
+    GpuInfer(argv[1], argv[2]);
+  } else if (std::atoi(argv[3]) == 2) {
+    TrtInfer(argv[1], argv[2]);
+  }
+  return 0;
+}
diff --git a/examples/vision/detection/yolov7end2end_ort/python/README.md b/examples/vision/detection/yolov7end2end_ort/python/README.md
new file mode 100644
index 000000000..3313d8e41
--- /dev/null
+++ b/examples/vision/detection/yolov7end2end_ort/python/README.md
@@ -0,0 +1,84 @@
+# YOLOv7End2EndORT Python Deployment Example
+
+Before deployment, confirm the following two steps:
+
+- 1. The hardware and software environment meets the requirements; see [FastDeploy Environment Requirements](../../../../../docs/the%20software%20and%20hardware%20requirements.md)
+- 2. Install the FastDeploy Python whl package; see [FastDeploy Python Installation](../../../../../docs/quick_start)
+
+This directory provides `infer.py` to quickly finish YOLOv7End2EndORT deployment on CPU/GPU. Run the following script to complete it:
+
+```bash
+# Download the example code for deployment
+git clone https://github.com/PaddlePaddle/FastDeploy.git
+cd FastDeploy/examples/vision/detection/yolov7end2end_ort/python/
+# If the prebuilt Python wheel does not support this model yet, build and install the latest Python package from the develop branch source.
+
+# Download the yolov7 model file and a test image
+wget https://bj.bcebos.com/paddlehub/fastdeploy/yolov7-end2end-ort-nms.onnx
+wget https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg
+
+# CPU inference
+python infer.py --model yolov7-end2end-ort-nms.onnx --image 000000014439.jpg --device cpu
+# GPU inference
+python infer.py --model yolov7-end2end-ort-nms.onnx --image 000000014439.jpg --device gpu
+# TensorRT + GPU inference (not supported yet; falls back to ORT + GPU)
+python infer.py --model yolov7-end2end-ort-nms.onnx --image 000000014439.jpg --device gpu --use_trt True
+```
+
+The visualized result is shown below:
+
+*(visualized detection result image)*
+
+Note that YOLOv7End2EndORT is dedicated to End2End models exported from YOLOv7 with the [ORT_NMS](https://github.com/WongKinYiu/yolov7/blob/main/models/experimental.py#L87) operator. For models exported without NMS, use the YOLOv7 class; for End2End models exported with [TRT_NMS](https://github.com/WongKinYiu/yolov7/blob/main/models/experimental.py#L111), use YOLOv7End2EndTRT.
+
+## YOLOv7End2EndORT Python Interface
+
+```python
+fastdeploy.vision.detection.YOLOv7End2EndORT(model_file, params_file=None, runtime_option=None, model_format=Frontend.ONNX)
+```
+
+Loads and initializes the YOLOv7End2EndORT model, where model_file is the exported ONNX model.
+
+**Parameters**
+
+> * **model_file**(str): path of the model file
+> * **params_file**(str): path of the parameter file; no need to set it when the model format is ONNX
+> * **runtime_option**(RuntimeOption): backend inference configuration; None uses the default configuration
+> * **model_format**(Frontend): model format, ONNX by default
+
+### predict Function
+
+> ```python
+> YOLOv7End2EndORT.predict(image_data, conf_threshold=0.25)
+> ```
+>
+> The prediction interface: takes an input image and returns the detection result directly.
+>
+> **Parameters**
+>
+> > * **image_data**(np.ndarray): input data; must be in HWC, BGR format
+> > * **conf_threshold**(float): confidence threshold for filtering boxes; since a score threshold was already specified when the YOLOv7 End2End model was exported to ONNX, this parameter only takes effect when it is greater than the threshold specified at export time.
+
+> **Returns**
+>
+> > a `fastdeploy.vision.DetectionResult` structure; see [Vision Model Prediction Results](../../../../../docs/api/vision_results/) for its description
+
+### Class Member Properties
+#### Preprocessing Parameters
+Users can modify the following preprocessing parameters according to their needs, which affects the final inference and deployment results
+
+> > * **size**(list[int]): target size of the resize during preprocessing, containing two integers meaning [width, height]; default [640, 640]
+> > * **padding_value**(list[float]): padding value applied during resize, containing three floats for the three channels; default [114, 114, 114]
+> > * **is_no_pad**(bool): whether to resize the image without padding; `is_no_pad=True` means padding is not used; default `is_no_pad=False`
+> > * **is_mini_pad**(bool): whether to pad the resized image only up to the size closest to `size` such that the number of padded pixels is divisible by the `stride` member; default `is_mini_pad=False`
+> > * **stride**(int): used together with the `is_mini_pad` member; default `stride=32`
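+
+For example, a minimal sketch of adjusting these properties before prediction (illustrative values; whether a model accepts a different input size depends on how it was exported):
+
+```python
+import fastdeploy as fd
+
+model = fd.vision.detection.YOLOv7End2EndORT("yolov7-end2end-ort-nms.onnx")
+model.padding_value = [114.0, 114.0, 114.0]  # per-channel letterbox fill value
+model.is_no_pad = False                      # keep letterbox padding enabled
+print(model.size)                            # [640, 640] by default
+```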
+
+
+## Other Documents
+
+- [YOLOv7End2EndORT Model Introduction](..)
+- [YOLOv7End2EndORT C++ Deployment](../cpp)
+- [Prediction Results Description](../../../../../docs/api/vision_results/)
diff --git a/examples/vision/detection/yolov7end2end_ort/python/infer.py b/examples/vision/detection/yolov7end2end_ort/python/infer.py
new file mode 100644
index 000000000..2b812b71a
--- /dev/null
+++ b/examples/vision/detection/yolov7end2end_ort/python/infer.py
@@ -0,0 +1,53 @@
+import fastdeploy as fd
+import cv2
+
+
+def parse_arguments():
+    import argparse
+    import ast
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--model", required=True, help="Path of yolov7 end2end onnx model.")
+    parser.add_argument(
+        "--image", required=True, help="Path of test image file.")
+    parser.add_argument(
+        "--device",
+        type=str,
+        default='cpu',
+        help="Type of inference device, support 'cpu' or 'gpu'.")
+    parser.add_argument(
+        "--use_trt",
+        type=ast.literal_eval,
+        default=False,
+        help="Whether to use tensorrt.")
+    return parser.parse_args()
+
+
+def build_option(args):
+    option = fd.RuntimeOption()
+
+    if args.device.lower() == "gpu":
+        option.use_gpu()
+
+    if args.use_trt:
+        option.use_trt_backend()
+        option.set_trt_input_shape("images", [1, 3, 640, 640])
+    return option
+
+
+args = parse_arguments()
+
+# Configure the runtime and load the model
+runtime_option = build_option(args)
+model = fd.vision.detection.YOLOv7End2EndORT(
+    args.model, runtime_option=runtime_option)
+
+# Predict the detection results for an image
+im = cv2.imread(args.image)
+result = model.predict(im.copy())
+print(result)
+
+# Visualize the prediction results
+vis_im = fd.vision.vis_detection(im, result)
+cv2.imwrite("visualized_result.jpg", vis_im)
+print("Visualized result saved in ./visualized_result.jpg")
diff --git a/fastdeploy/vision/detection/__init__.py b/fastdeploy/vision/detection/__init__.py
index 869b47ee5..7e41044ba 100644
--- a/fastdeploy/vision/detection/__init__.py
+++ b/fastdeploy/vision/detection/__init__.py
@@ -21,4 +21,5 @@ from .contrib.yolox import YOLOX
 from .contrib.yolov5 import YOLOv5
 from .contrib.yolov5lite import YOLOv5Lite
 from .contrib.yolov6 import YOLOv6
+from .contrib.yolov7end2end_ort import YOLOv7End2EndORT
 from .ppdet import PPYOLOE, PPYOLO, PPYOLOv2, PaddleYOLOX, PicoDet, FasterRCNN, YOLOv3
diff --git a/fastdeploy/vision/detection/contrib/yolov7end2end_ort.py b/fastdeploy/vision/detection/contrib/yolov7end2end_ort.py
new file mode 100644
index 000000000..e937d368d
--- /dev/null
+++ b/fastdeploy/vision/detection/contrib/yolov7end2end_ort.py
@@ -0,0 +1,105 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+import logging
+from .... import FastDeployModel, Frontend
+from .... import c_lib_wrap as C
+
+
+class YOLOv7End2EndORT(FastDeployModel):
+    def __init__(self,
+                 model_file,
+                 params_file="",
+                 runtime_option=None,
+                 model_format=Frontend.ONNX):
+        # Call the base class to initialize backend_option;
+        # the initialized option is stored in self._runtime_option
+        super(YOLOv7End2EndORT, self).__init__(runtime_option)
+
+        self._model = C.vision.detection.YOLOv7End2EndORT(
+            model_file, params_file, self._runtime_option, model_format)
+        # self.initialized indicates whether the whole model initialized successfully
+        assert self.initialized, "YOLOv7End2EndORT initialize failed."
+
+    def predict(self, input_image, conf_threshold=0.25):
+        return self._model.predict(input_image, conf_threshold)
+
+    # Wrappers for some model-related properties, mostly for preprocessing.
+    # For example, model.size = [1280, 1280] changes the resize target during
+    # preprocessing (provided the model supports it).
+    @property
+    def size(self):
+        return self._model.size
+
+    @property
+    def padding_value(self):
+        return self._model.padding_value
+
+    @property
+    def is_no_pad(self):
+        return self._model.is_no_pad
+
+    @property
+    def is_mini_pad(self):
+        return self._model.is_mini_pad
+
+    @property
+    def is_scale_up(self):
+        return self._model.is_scale_up
+
+    @property
+    def stride(self):
+        return self._model.stride
+
+    @size.setter
+    def size(self, wh):
+        assert isinstance(wh, (list, tuple)),\
+            "The value to set `size` must be type of tuple or list."
+        assert len(wh) == 2,\
+            "The value to set `size` must contains 2 elements means [width, height], but now it contains {} elements.".format(
+            len(wh))
+        self._model.size = wh
+
+    @padding_value.setter
+    def padding_value(self, value):
+        assert isinstance(
+            value,
+            list), "The value to set `padding_value` must be type of list."
+        self._model.padding_value = value
+
+    @is_no_pad.setter
+    def is_no_pad(self, value):
+        assert isinstance(
+            value, bool), "The value to set `is_no_pad` must be type of bool."
+        self._model.is_no_pad = value
+
+    @is_mini_pad.setter
+    def is_mini_pad(self, value):
+        assert isinstance(
+            value,
+            bool), "The value to set `is_mini_pad` must be type of bool."
+        self._model.is_mini_pad = value
+
+    @is_scale_up.setter
+    def is_scale_up(self, value):
+        assert isinstance(
+            value,
+            bool), "The value to set `is_scale_up` must be type of bool."
+ self._model.is_scale_up = value + + @stride.setter + def stride(self, value): + assert isinstance( + value, int), "The value to set `stride` must be type of int." + self._model.stride = value