diff --git a/.gitignore b/.gitignore index 39783b883..51f2f2ed8 100644 --- a/.gitignore +++ b/.gitignore @@ -11,4 +11,4 @@ fastdeploy.egg-info .setuptools-cmake-build fastdeploy/version.py fastdeploy/LICENSE* -fastdeploy/ThirdPartyNotices* \ No newline at end of file +fastdeploy/ThirdPartyNotices* diff --git a/examples/vision/wongkinyiu_yolor.cc b/examples/vision/wongkinyiu_yolor.cc new file mode 100644 index 000000000..abdca2b7f --- /dev/null +++ b/examples/vision/wongkinyiu_yolor.cc @@ -0,0 +1,52 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fastdeploy/vision.h" + +int main() { + namespace vis = fastdeploy::vision; + + std::string model_file = "../resources/models/yolor.onnx"; + std::string img_path = "../resources/images/horses.jpg"; + std::string vis_path = "../resources/outputs/wongkinyiu_yolor_vis_result.jpg"; + + auto model = vis::wongkinyiu::YOLOR(model_file); + if (!model.Initialized()) { + std::cerr << "Init Failed! Model: " << model_file << std::endl; + return -1; + } else { + std::cout << "Init Done! Model:" << model_file << std::endl; + } + model.EnableDebug(); + + cv::Mat im = cv::imread(img_path); + cv::Mat vis_im = im.clone(); + + vis::DetectionResult res; + if (!model.Predict(&im, &res)) { + std::cerr << "Prediction Failed." << std::endl; + return -1; + } else { + std::cout << "Prediction Done!" << std::endl; + } + + // 输出预测框结果 + std::cout << res.Str() << std::endl; + + // 可视化预测结果 + vis::Visualize::VisDetection(&vis_im, res); + cv::imwrite(vis_path, vis_im); + std::cout << "Detect Done! Saved: " << vis_path << std::endl; + return 0; +} diff --git a/fastdeploy/utils/utils.h b/fastdeploy/utils/utils.h index 931208426..b57a27f80 100644 --- a/fastdeploy/utils/utils.h +++ b/fastdeploy/utils/utils.h @@ -26,10 +26,10 @@ #define FASTDEPLOY_DECL __declspec(dllexport) #else #define FASTDEPLOY_DECL __declspec(dllimport) -#endif // FASTDEPLOY_LIB +#endif // FASTDEPLOY_LIB #else #define FASTDEPLOY_DECL __attribute__((visibility("default"))) -#endif // _WIN32 +#endif // _WIN32 namespace fastdeploy { @@ -42,7 +42,8 @@ class FASTDEPLOY_DECL FDLogger { } explicit FDLogger(bool verbose, const std::string& prefix = "[FastDeploy]"); - template FDLogger& operator<<(const T& val) { + template + FDLogger& operator<<(const T& val) { if (!verbose_) { return *this; } @@ -72,10 +73,14 @@ class FASTDEPLOY_DECL FDLogger { FDLogger(true, "[ERROR]") \ << __REL_FILE__ << "(" << __LINE__ << ")::" << __FUNCTION__ << "\t" -#define FDASSERT(condition, message) \ - if (!(condition)) { \ - FDERROR << message << std::endl; \ - std::abort(); \ +#define FDERROR \ + FDLogger(true, "[ERROR]") << __REL_FILE__ << "(" << __LINE__ \ + << ")::" << __FUNCTION__ << "\t" + +#define FDASSERT(condition, message) \ + if (!(condition)) { \ + FDERROR << message << std::endl; \ + std::abort(); \ } -} // namespace fastdeploy +} // namespace fastdeploy diff --git a/fastdeploy/vision.h b/fastdeploy/vision.h index cafe310c7..68c0881ca 100644 --- a/fastdeploy/vision.h +++ b/fastdeploy/vision.h @@ -15,12 +15,13 @@ #include "fastdeploy/core/config.h" #ifdef ENABLE_VISION +#include "fastdeploy/vision/megvii/yolox.h" +#include "fastdeploy/vision/meituan/yolov6.h" #include "fastdeploy/vision/ppcls/model.h" #include "fastdeploy/vision/ppdet/ppyoloe.h" #include "fastdeploy/vision/ultralytics/yolov5.h" +#include "fastdeploy/vision/wongkinyiu/yolor.h" #include "fastdeploy/vision/wongkinyiu/yolov7.h" -#include "fastdeploy/vision/meituan/yolov6.h" -#include "fastdeploy/vision/megvii/yolox.h" #endif #include "fastdeploy/vision/visualize/visualize.h" diff --git a/fastdeploy/vision/ppcls/model.h b/fastdeploy/vision/ppcls/model.h index 265f92d32..71800a7d7 100644 --- a/fastdeploy/vision/ppcls/model.h +++ b/fastdeploy/vision/ppcls/model.h @@ -46,6 +46,6 @@ class FASTDEPLOY_DECL Model : public FastDeployModel { std::vector> processors_; std::string config_file_; }; -} // namespace ppcls -} // namespace vision -} // namespace fastdeploy +} // namespace ppcls +} // namespace vision +} // namespace fastdeploy diff --git a/fastdeploy/vision/ppcls/ppcls_pybind.cc b/fastdeploy/vision/ppcls/ppcls_pybind.cc index 1abc0b2b7..10ff5ee10 100644 --- a/fastdeploy/vision/ppcls/ppcls_pybind.cc +++ b/fastdeploy/vision/ppcls/ppcls_pybind.cc @@ -27,4 +27,4 @@ void BindPPCls(pybind11::module& m) { return res; }); } -} // namespace fastdeploy +} // namespace fastdeploy diff --git a/fastdeploy/vision/visualize/detection.cc b/fastdeploy/vision/visualize/detection.cc index 5b5538bff..6d6007244 100644 --- a/fastdeploy/vision/visualize/detection.cc +++ b/fastdeploy/vision/visualize/detection.cc @@ -56,6 +56,6 @@ void Visualize::VisDetection(cv::Mat* im, const DetectionResult& result, } } -} // namespace vision -} // namespace fastdeploy +} // namespace vision +} // namespace fastdeploy #endif diff --git a/fastdeploy/vision/wongkinyiu/__init__.py b/fastdeploy/vision/wongkinyiu/__init__.py index 542389e20..3c77e8589 100644 --- a/fastdeploy/vision/wongkinyiu/__init__.py +++ b/fastdeploy/vision/wongkinyiu/__init__.py @@ -114,3 +114,101 @@ class YOLOv7(FastDeployModel): assert isinstance( value, float), "The value to set `max_wh` must be type of float." self._model.max_wh = value + + +class YOLOR(FastDeployModel): + def __init__(self, + model_file, + params_file="", + runtime_option=None, + model_format=Frontend.ONNX): + # 调用基函数进行backend_option的初始化 + # 初始化后的option保存在self._runtime_option + super(YOLOR, self).__init__(runtime_option) + + self._model = C.vision.wongkinyiu.YOLOR( + model_file, params_file, self._runtime_option, model_format) + # 通过self.initialized判断整个模型的初始化是否成功 + assert self.initialized, "YOLOR initialize failed." + + def predict(self, input_image, conf_threshold=0.25, nms_iou_threshold=0.5): + return self._model.predict(input_image, conf_threshold, + nms_iou_threshold) + + # 一些跟YOLOR模型有关的属性封装 + # 多数是预处理相关,可通过修改如model.size = [1280, 1280]改变预处理时resize的大小(前提是模型支持) + @property + def size(self): + return self._model.size + + @property + def padding_value(self): + return self._model.padding_value + + @property + def is_no_pad(self): + return self._model.is_no_pad + + @property + def is_mini_pad(self): + return self._model.is_mini_pad + + @property + def is_scale_up(self): + return self._model.is_scale_up + + @property + def stride(self): + return self._model.stride + + @property + def max_wh(self): + return self._model.max_wh + + @size.setter + def size(self, wh): + assert isinstance(wh, [list, tuple]),\ + "The value to set `size` must be type of tuple or list." + assert len(wh) == 2,\ + "The value to set `size` must contatins 2 elements means [width, height], but now it contains {} elements.".format( + len(wh)) + self._model.size = wh + + @padding_value.setter + def padding_value(self, value): + assert isinstance( + value, + list), "The value to set `padding_value` must be type of list." + self._model.padding_value = value + + @is_no_pad.setter + def is_no_pad(self, value): + assert isinstance( + value, bool), "The value to set `is_no_pad` must be type of bool." + self._model.is_no_pad = value + + @is_mini_pad.setter + def is_mini_pad(self, value): + assert isinstance( + value, + bool), "The value to set `is_mini_pad` must be type of bool." + self._model.is_mini_pad = value + + @is_scale_up.setter + def is_scale_up(self, value): + assert isinstance( + value, + bool), "The value to set `is_scale_up` must be type of bool." + self._model.is_scale_up = value + + @stride.setter + def stride(self, value): + assert isinstance( + value, int), "The value to set `stride` must be type of int." + self._model.stride = value + + @max_wh.setter + def max_wh(self, value): + assert isinstance( + value, float), "The value to set `max_wh` must be type of float." + self._model.max_wh = value diff --git a/fastdeploy/vision/wongkinyiu/wongkinyiu_pybind.cc b/fastdeploy/vision/wongkinyiu/wongkinyiu_pybind.cc index 4a10f47a7..6bde2a184 100644 --- a/fastdeploy/vision/wongkinyiu/wongkinyiu_pybind.cc +++ b/fastdeploy/vision/wongkinyiu/wongkinyiu_pybind.cc @@ -17,7 +17,7 @@ namespace fastdeploy { void BindWongkinyiu(pybind11::module& m) { auto wongkinyiu_module = - m.def_submodule("wongkinyiu", "https://github.com/WongKinYiu/yolov7"); + m.def_submodule("wongkinyiu", "https://github.com/WongKinYiu"); pybind11::class_( wongkinyiu_module, "YOLOv7") .def(pybind11::init()) @@ -37,5 +37,24 @@ void BindWongkinyiu(pybind11::module& m) { .def_readwrite("is_scale_up", &vision::wongkinyiu::YOLOv7::is_scale_up) .def_readwrite("stride", &vision::wongkinyiu::YOLOv7::stride) .def_readwrite("max_wh", &vision::wongkinyiu::YOLOv7::max_wh); + + pybind11::class_( + wongkinyiu_module, "YOLOR") + .def(pybind11::init()) + .def("predict", + [](vision::wongkinyiu::YOLOR& self, pybind11::array& data, + float conf_threshold, float nms_iou_threshold) { + auto mat = PyArrayToCvMat(data); + vision::DetectionResult res; + self.Predict(&mat, &res, conf_threshold, nms_iou_threshold); + return res; + }) + .def_readwrite("size", &vision::wongkinyiu::YOLOR::size) + .def_readwrite("padding_value", &vision::wongkinyiu::YOLOR::padding_value) + .def_readwrite("is_mini_pad", &vision::wongkinyiu::YOLOR::is_mini_pad) + .def_readwrite("is_no_pad", &vision::wongkinyiu::YOLOR::is_no_pad) + .def_readwrite("is_scale_up", &vision::wongkinyiu::YOLOR::is_scale_up) + .def_readwrite("stride", &vision::wongkinyiu::YOLOR::stride) + .def_readwrite("max_wh", &vision::wongkinyiu::YOLOR::max_wh); } } // namespace fastdeploy diff --git a/fastdeploy/vision/wongkinyiu/yolor.cc b/fastdeploy/vision/wongkinyiu/yolor.cc new file mode 100644 index 000000000..5cf9d6cb8 --- /dev/null +++ b/fastdeploy/vision/wongkinyiu/yolor.cc @@ -0,0 +1,243 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fastdeploy/vision/wongkinyiu/yolor.h" +#include "fastdeploy/utils/perf.h" +#include "fastdeploy/vision/utils/utils.h" + +namespace fastdeploy { +namespace vision { +namespace wongkinyiu { + +void YOLOR::LetterBox(Mat* mat, const std::vector& size, + const std::vector& color, bool _auto, + bool scale_fill, bool scale_up, int stride) { + float scale = + std::min(size[1] * 1.0 / mat->Height(), size[0] * 1.0 / mat->Width()); + if (!scale_up) { + scale = std::min(scale, 1.0f); + } + + int resize_h = int(round(mat->Height() * scale)); + int resize_w = int(round(mat->Width() * scale)); + + int pad_w = size[0] - resize_w; + int pad_h = size[1] - resize_h; + if (_auto) { + pad_h = pad_h % stride; + pad_w = pad_w % stride; + } else if (scale_fill) { + pad_h = 0; + pad_w = 0; + resize_h = size[1]; + resize_w = size[0]; + } + Resize::Run(mat, resize_w, resize_h); + if (pad_h > 0 || pad_w > 0) { + float half_h = pad_h * 1.0 / 2; + int top = int(round(half_h - 0.1)); + int bottom = int(round(half_h + 0.1)); + float half_w = pad_w * 1.0 / 2; + int left = int(round(half_w - 0.1)); + int right = int(round(half_w + 0.1)); + Pad::Run(mat, top, bottom, left, right, color); + } +} + +YOLOR::YOLOR(const std::string& model_file, const std::string& params_file, + const RuntimeOption& custom_option, const Frontend& model_format) { + if (model_format == Frontend::ONNX) { + valid_cpu_backends = {Backend::ORT}; // 指定可用的CPU后端 + valid_gpu_backends = {Backend::ORT, Backend::TRT}; // 指定可用的GPU后端 + } else { + valid_cpu_backends = {Backend::PDINFER, Backend::ORT}; + valid_gpu_backends = {Backend::PDINFER, Backend::ORT, Backend::TRT}; + } + runtime_option = custom_option; + runtime_option.model_format = model_format; + runtime_option.model_file = model_file; + runtime_option.params_file = params_file; + initialized = Initialize(); +} + +bool YOLOR::Initialize() { + // parameters for preprocess + size = {640, 640}; + padding_value = {114.0, 114.0, 114.0}; + is_mini_pad = false; + is_no_pad = false; + is_scale_up = false; + stride = 32; + max_wh = 7680.0; + + if (!InitRuntime()) { + FDERROR << "Failed to initialize fastdeploy backend." << std::endl; + return false; + } + return true; +} + +bool YOLOR::Preprocess(Mat* mat, FDTensor* output, + std::map>* im_info) { + // process after image load + double ratio = (size[0] * 1.0) / std::max(static_cast(mat->Height()), + static_cast(mat->Width())); + if (ratio != 1.0) { + int interp = cv::INTER_AREA; + if (ratio > 1.0) { + interp = cv::INTER_LINEAR; + } + int resize_h = int(mat->Height() * ratio); + int resize_w = int(mat->Width() * ratio); + Resize::Run(mat, resize_w, resize_h, -1, -1, interp); + } + // yolor's preprocess steps + // 1. letterbox + // 2. BGR->RGB + // 3. HWC->CHW + YOLOR::LetterBox(mat, size, padding_value, is_mini_pad, is_no_pad, + is_scale_up, stride); + BGR2RGB::Run(mat); + Normalize::Run(mat, std::vector(mat->Channels(), 0.0), + std::vector(mat->Channels(), 1.0)); + + // Record output shape of preprocessed image + (*im_info)["output_shape"] = {static_cast(mat->Height()), + static_cast(mat->Width())}; + + HWC2CHW::Run(mat); + Cast::Run(mat, "float"); + mat->ShareWithTensor(output); + output->shape.insert(output->shape.begin(), 1); // reshape to n, h, w, c + return true; +} + +bool YOLOR::Postprocess( + FDTensor& infer_result, DetectionResult* result, + const std::map>& im_info, + float conf_threshold, float nms_iou_threshold) { + FDASSERT(infer_result.shape[0] == 1, "Only support batch =1 now."); + result->Clear(); + result->Reserve(infer_result.shape[1]); + if (infer_result.dtype != FDDataType::FP32) { + FDERROR << "Only support post process with float32 data." << std::endl; + return false; + } + float* data = static_cast(infer_result.Data()); + for (size_t i = 0; i < infer_result.shape[1]; ++i) { + int s = i * infer_result.shape[2]; + float confidence = data[s + 4]; + float* max_class_score = + std::max_element(data + s + 5, data + s + infer_result.shape[2]); + confidence *= (*max_class_score); + // filter boxes by conf_threshold + if (confidence <= conf_threshold) { + continue; + } + int32_t label_id = std::distance(data + s + 5, max_class_score); + // convert from [x, y, w, h] to [x1, y1, x2, y2] + result->boxes.emplace_back(std::array{ + data[s] - data[s + 2] / 2.0f + label_id * max_wh, + data[s + 1] - data[s + 3] / 2.0f + label_id * max_wh, + data[s + 0] + data[s + 2] / 2.0f + label_id * max_wh, + data[s + 1] + data[s + 3] / 2.0f + label_id * max_wh}); + result->label_ids.push_back(label_id); + result->scores.push_back(confidence); + } + utils::NMS(result, nms_iou_threshold); + + // scale the boxes to the origin image shape + auto iter_out = im_info.find("output_shape"); + auto iter_ipt = im_info.find("input_shape"); + FDASSERT(iter_out != im_info.end() && iter_ipt != im_info.end(), + "Cannot find input_shape or output_shape from im_info."); + float out_h = iter_out->second[0]; + float out_w = iter_out->second[1]; + float ipt_h = iter_ipt->second[0]; + float ipt_w = iter_ipt->second[1]; + float scale = std::min(out_h / ipt_h, out_w / ipt_w); + for (size_t i = 0; i < result->boxes.size(); ++i) { + float pad_h = (out_h - ipt_h * scale) / 2; + float pad_w = (out_w - ipt_w * scale) / 2; + int32_t label_id = (result->label_ids)[i]; + // clip box + result->boxes[i][0] = result->boxes[i][0] - max_wh * label_id; + result->boxes[i][1] = result->boxes[i][1] - max_wh * label_id; + result->boxes[i][2] = result->boxes[i][2] - max_wh * label_id; + result->boxes[i][3] = result->boxes[i][3] - max_wh * label_id; + result->boxes[i][0] = std::max((result->boxes[i][0] - pad_w) / scale, 0.0f); + result->boxes[i][1] = std::max((result->boxes[i][1] - pad_h) / scale, 0.0f); + result->boxes[i][2] = std::max((result->boxes[i][2] - pad_w) / scale, 0.0f); + result->boxes[i][3] = std::max((result->boxes[i][3] - pad_h) / scale, 0.0f); + result->boxes[i][0] = std::min(result->boxes[i][0], ipt_w); + result->boxes[i][1] = std::min(result->boxes[i][1], ipt_h); + result->boxes[i][2] = std::min(result->boxes[i][2], ipt_w); + result->boxes[i][3] = std::min(result->boxes[i][3], ipt_h); + } + return true; +} + +bool YOLOR::Predict(cv::Mat* im, DetectionResult* result, float conf_threshold, + float nms_iou_threshold) { +#ifdef FASTDEPLOY_DEBUG + TIMERECORD_START(0) +#endif + + Mat mat(*im); + std::vector input_tensors(1); + + std::map> im_info; + + // Record the shape of image and the shape of preprocessed image + im_info["input_shape"] = {static_cast(mat.Height()), + static_cast(mat.Width())}; + im_info["output_shape"] = {static_cast(mat.Height()), + static_cast(mat.Width())}; + + if (!Preprocess(&mat, &input_tensors[0], &im_info)) { + FDERROR << "Failed to preprocess input image." << std::endl; + return false; + } + +#ifdef FASTDEPLOY_DEBUG + TIMERECORD_END(0, "Preprocess") + TIMERECORD_START(1) +#endif + + input_tensors[0].name = InputInfoOfRuntime(0).name; + std::vector output_tensors; + if (!Infer(input_tensors, &output_tensors)) { + FDERROR << "Failed to inference." << std::endl; + return false; + } +#ifdef FASTDEPLOY_DEBUG + TIMERECORD_END(1, "Inference") + TIMERECORD_START(2) +#endif + + if (!Postprocess(output_tensors[0], result, im_info, conf_threshold, + nms_iou_threshold)) { + FDERROR << "Failed to post process." << std::endl; + return false; + } + +#ifdef FASTDEPLOY_DEBUG + TIMERECORD_END(2, "Postprocess") +#endif + return true; +} + +} // namespace wongkinyiu +} // namespace vision +} // namespace fastdeploy diff --git a/fastdeploy/vision/wongkinyiu/yolor.h b/fastdeploy/vision/wongkinyiu/yolor.h new file mode 100644 index 000000000..69f5ea876 --- /dev/null +++ b/fastdeploy/vision/wongkinyiu/yolor.h @@ -0,0 +1,95 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "fastdeploy/fastdeploy_model.h" +#include "fastdeploy/vision/common/processors/transform.h" +#include "fastdeploy/vision/common/result.h" + +namespace fastdeploy { +namespace vision { +namespace wongkinyiu { + +class FASTDEPLOY_DECL YOLOR : public FastDeployModel { + public: + // 当model_format为ONNX时,无需指定params_file + // 当model_format为Paddle时,则需同时指定model_file & params_file + YOLOR(const std::string& model_file, const std::string& params_file = "", + const RuntimeOption& custom_option = RuntimeOption(), + const Frontend& model_format = Frontend::ONNX); + + // 定义模型的名称 + virtual std::string ModelName() const { return "WongKinYiu/yolor"; } + + // 模型预测接口,即用户调用的接口 + // im 为用户的输入数据,目前对于CV均定义为cv::Mat + // result 为模型预测的输出结构体 + // conf_threshold 为后处理的参数 + // nms_iou_threshold 为后处理的参数 + virtual bool Predict(cv::Mat* im, DetectionResult* result, + float conf_threshold = 0.25, + float nms_iou_threshold = 0.5); + + // 以下为模型在预测时的一些参数,基本是前后处理所需 + // 用户在创建模型后,可根据模型的要求,以及自己的需求 + // 对参数进行修改 + // tuple of (width, height) + std::vector size; + // padding value, size should be same with Channels + std::vector padding_value; + // only pad to the minimum rectange which height and width is times of stride + bool is_mini_pad; + // while is_mini_pad = false and is_no_pad = true, will resize the image to + // the set size + bool is_no_pad; + // if is_scale_up is false, the input image only can be zoom out, the maximum + // resize scale cannot exceed 1.0 + bool is_scale_up; + // padding stride, for is_mini_pad + int stride; + // for offseting the boxes by classes when using NMS + float max_wh; + + private: + // 初始化函数,包括初始化后端,以及其它模型推理需要涉及的操作 + bool Initialize(); + + // 输入图像预处理操作 + // Mat为FastDeploy定义的数据结构 + // FDTensor为预处理后的Tensor数据,传给后端进行推理 + // im_info为预处理过程保存的数据,在后处理中需要用到 + bool Preprocess(Mat* mat, FDTensor* outputs, + std::map>* im_info); + + // 后端推理结果后处理,输出给用户 + // infer_result 为后端推理后的输出Tensor + // result 为模型预测的结果 + // im_info 为预处理记录的信息,后处理用于还原box + // conf_threshold 后处理时过滤box的置信度阈值 + // nms_iou_threshold 后处理时NMS设定的iou阈值 + bool Postprocess(FDTensor& infer_result, DetectionResult* result, + const std::map>& im_info, + float conf_threshold, float nms_iou_threshold); + + // 对图片进行LetterBox处理 + // mat 为读取到的原图 + // size 为输入模型的图像尺寸 + void LetterBox(Mat* mat, const std::vector& size, + const std::vector& color, bool _auto, + bool scale_fill = false, bool scale_up = true, + int stride = 32); +}; +} // namespace wongkinyiu +} // namespace vision +} // namespace fastdeploy diff --git a/fastdeploy/vision/wongkinyiu/yolov7.cc b/fastdeploy/vision/wongkinyiu/yolov7.cc index 248718a69..532f55294 100644 --- a/fastdeploy/vision/wongkinyiu/yolov7.cc +++ b/fastdeploy/vision/wongkinyiu/yolov7.cc @@ -20,9 +20,9 @@ namespace fastdeploy { namespace vision { namespace wongkinyiu { -void LetterBox(Mat* mat, const std::vector& size, - const std::vector& color, bool _auto, - bool scale_fill = false, bool scale_up = true, int stride = 32) { +void YOLOv7::LetterBox(Mat* mat, const std::vector& size, + const std::vector& color, bool _auto, + bool scale_fill, bool scale_up, int stride) { float scale = std::min(size[1] * 1.0 / mat->Height(), size[0] * 1.0 / mat->Width()); if (!scale_up) { @@ -107,8 +107,8 @@ bool YOLOv7::Preprocess(Mat* mat, FDTensor* output, // 1. letterbox // 2. BGR->RGB // 3. HWC->CHW - LetterBox(mat, size, padding_value, is_mini_pad, is_no_pad, is_scale_up, - stride); + YOLOv7::LetterBox(mat, size, padding_value, is_mini_pad, is_no_pad, + is_scale_up, stride); BGR2RGB::Run(mat); Normalize::Run(mat, std::vector(mat->Channels(), 0.0), std::vector(mat->Channels(), 1.0)); diff --git a/fastdeploy/vision/wongkinyiu/yolov7.h b/fastdeploy/vision/wongkinyiu/yolov7.h index 75cab34de..c494754f0 100644 --- a/fastdeploy/vision/wongkinyiu/yolov7.h +++ b/fastdeploy/vision/wongkinyiu/yolov7.h @@ -70,7 +70,7 @@ class FASTDEPLOY_DECL YOLOv7 : public FastDeployModel { // FDTensor为预处理后的Tensor数据,传给后端进行推理 // im_info为预处理过程保存的数据,在后处理中需要用到 bool Preprocess(Mat* mat, FDTensor* outputs, - std::map>* im_info); + std::map>* im_info); // 后端推理结果后处理,输出给用户 // infer_result 为后端推理后的输出Tensor @@ -78,10 +78,17 @@ class FASTDEPLOY_DECL YOLOv7 : public FastDeployModel { // im_info 为预处理记录的信息,后处理用于还原box // conf_threshold 后处理时过滤box的置信度阈值 // nms_iou_threshold 后处理时NMS设定的iou阈值 - bool Postprocess( - FDTensor& infer_result, DetectionResult* result, - const std::map>& im_info, - float conf_threshold, float nms_iou_threshold); + bool Postprocess(FDTensor& infer_result, DetectionResult* result, + const std::map>& im_info, + float conf_threshold, float nms_iou_threshold); + + // 对图片进行LetterBox处理 + // mat 为读取到的原图 + // size 为输入模型的图像尺寸 + void LetterBox(Mat* mat, const std::vector& size, + const std::vector& color, bool _auto, + bool scale_fill = false, bool scale_up = true, + int stride = 32); }; } // namespace wongkinyiu } // namespace vision diff --git a/model_zoo/vision/yolor/README.md b/model_zoo/vision/yolor/README.md new file mode 100644 index 000000000..358e62bbe --- /dev/null +++ b/model_zoo/vision/yolor/README.md @@ -0,0 +1,66 @@ +# 编译YOLOR示例 + +当前支持模型版本为:[YOLOR weights](https://github.com/WongKinYiu/yolor/releases/tag/weights) +(tips: 如果使用 `git clone` 的方式下载仓库代码,请将分支切换(checkout)到 `paper` 分支). + +本文档说明如何进行[YOLOR](https://github.com/WongKinYiu/yolor)的快速部署推理。本目录结构如下 + +``` +. +├── cpp +│   ├── CMakeLists.txt +│   ├── README.md +│   └── yolor.cc +├── README.md +└── yolor.py +``` + +## 获取ONNX文件 + +- 手动获取 + + 访问[YOLOR](https://github.com/WongKinYiu/yolor)官方github库,按照指引下载安装,下载`yolor.pt` 模型,利用 `models/export.py` 得到`onnx`格式文件。如果您导出的`onnx`模型出现精度不达标或者是数据维度的问题,可以参考[yolor#32](https://github.com/WongKinYiu/yolor/issues/32)的解决办法 + + ``` + #下载yolor模型文件 + wget https://github.com/WongKinYiu/yolor/releases/download/weights/yolor-d6-paper-570.pt + + # 导出onnx格式文件 + python models/export.py --weights PATH/TO/yolor-xx-xx-xx.pt --img-size 640 + + # 移动onnx文件到demo目录 + cp PATH/TO/yolor.onnx PATH/TO/model_zoo/vision/yolor/ + ``` + +## 安装FastDeploy + +使用如下命令安装FastDeploy,注意到此处安装的是`vision-cpu`,也可根据需求安装`vision-gpu` + +``` +# 安装fastdeploy-python工具 +pip install fastdeploy-python + +# 安装vision-cpu模块 +fastdeploy install vision-cpu +``` +## Python部署 + +执行如下代码即会自动下载测试图片 +``` +python yolor.py +``` + +执行完成后会将可视化结果保存在本地`vis_result.jpg`,同时输出检测结果如下 +``` +DetectionResult: [xmin, ymin, xmax, ymax, score, label_id] +0.000000,185.201431, 315.673126, 410.071594, 0.959289, 17 +433.802826,211.603455, 595.489319, 346.425537, 0.952615, 17 +230.446854,195.618805, 418.365479, 362.712128, 0.884253, 17 +336.545624,208.555618, 457.704315, 323.543152, 0.788450, 17 +0.896423,183.936996, 154.788727, 304.916412, 0.672804, 17 +``` + +## 其它文档 + +- [C++部署](./cpp/README.md) +- [YOLOR API文档](./api.md) diff --git a/model_zoo/vision/yolor/api.md b/model_zoo/vision/yolor/api.md new file mode 100644 index 000000000..b1e5be889 --- /dev/null +++ b/model_zoo/vision/yolor/api.md @@ -0,0 +1,71 @@ +# YOLOR API说明 + +## Python API + +### YOLOR类 +``` +fastdeploy.vision.wongkinyiu.YOLOR(model_file, params_file=None, runtime_option=None, model_format=fd.Frontend.ONNX) +``` +YOLOR模型加载和初始化,当model_format为`fd.Frontend.ONNX`时,只需提供model_file,如`yolor.onnx`;当model_format为`fd.Frontend.PADDLE`时,则需同时提供model_file和params_file。 + +**参数** + +> * **model_file**(str): 模型文件路径 +> * **params_file**(str): 参数文件路径 +> * **runtime_option**(RuntimeOption): 后端推理配置,默认为None,即采用默认配置 +> * **model_format**(Frontend): 模型格式 + +#### predict函数 +> ``` +> YOLOR.predict(image_data, conf_threshold=0.25, nms_iou_threshold=0.5) +> ``` +> 模型预测结口,输入图像直接输出检测结果。 +> +> **参数** +> +> > * **image_data**(np.ndarray): 输入数据,注意需为HWC,BGR格式 +> > * **conf_threshold**(float): 检测框置信度过滤阈值 +> > * **nms_iou_threshold**(float): NMS处理过程中iou阈值 + +示例代码参考[yolor.py](./yolor.py) + + +## C++ API + +### YOLOR类 +``` +fastdeploy::vision::wongkinyiu::YOLOR( + const string& model_file, + const string& params_file = "", + const RuntimeOption& runtime_option = RuntimeOption(), + const Frontend& model_format = Frontend::ONNX) +``` +YOLOR模型加载和初始化,当model_format为`Frontend::ONNX`时,只需提供model_file,如`yolor.onnx`;当model_format为`Frontend::PADDLE`时,则需同时提供model_file和params_file。 + +**参数** + +> * **model_file**(str): 模型文件路径 +> * **params_file**(str): 参数文件路径 +> * **runtime_option**(RuntimeOption): 后端推理配置,默认为None,即采用默认配置 +> * **model_format**(Frontend): 模型格式 + +#### Predict函数 +> ``` +> YOLOR::Predict(cv::Mat* im, DetectionResult* result, +> float conf_threshold = 0.25, +> float nms_iou_threshold = 0.5) +> ``` +> 模型预测接口,输入图像直接输出检测结果。 +> +> **参数** +> +> > * **im**: 输入图像,注意需为HWC,BGR格式 +> > * **result**: 检测结果,包括检测框,各个框的置信度 +> > * **conf_threshold**: 检测框置信度过滤阈值 +> > * **nms_iou_threshold**: NMS处理过程中iou阈值 + +示例代码参考[cpp/yolor.cc](cpp/yolor.cc) + +## 其它API使用 + +- [模型部署RuntimeOption配置](../../../docs/api/runtime_option.md) diff --git a/model_zoo/vision/yolor/cpp/CMakeLists.txt b/model_zoo/vision/yolor/cpp/CMakeLists.txt new file mode 100644 index 000000000..18248b845 --- /dev/null +++ b/model_zoo/vision/yolor/cpp/CMakeLists.txt @@ -0,0 +1,17 @@ +PROJECT(yolor_demo C CXX) +CMAKE_MINIMUM_REQUIRED (VERSION 3.16) + +# 在低版本ABI环境中,通过如下代码进行兼容性编译 +# add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0) + +# 指定下载解压后的fastdeploy库路径 +set(FASTDEPLOY_INSTALL_DIR ${PROJECT_SOURCE_DIR}/fastdeploy-linux-x64-0.3.0/) + +include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake) + +# 添加FastDeploy依赖头文件 +include_directories(${FASTDEPLOY_INCS}) + +add_executable(yolor_demo ${PROJECT_SOURCE_DIR}/yolor.cc) +# 添加FastDeploy库依赖 +target_link_libraries(yolor_demo ${FASTDEPLOY_LIBS}) diff --git a/model_zoo/vision/yolor/cpp/README.md b/model_zoo/vision/yolor/cpp/README.md new file mode 100644 index 000000000..d06bbe300 --- /dev/null +++ b/model_zoo/vision/yolor/cpp/README.md @@ -0,0 +1,53 @@ +# 编译YOLOR示例 + +当前支持模型版本为:[YOLOR weights](https://github.com/WongKinYiu/yolor/releases/tag/weights) +(tips: 如果使用 `git clone` 的方式下载仓库代码,请将分支切换(checkout)到 `paper` 分支). +## 获取ONNX文件 + +- 手动获取 + + 访问[YOLOR](https://github.com/WongKinYiu/yolor)官方github库,按照指引下载安装,下载`yolor.pt` 模型,利用 `models/export.py` 得到`onnx`格式文件。如果您导出的`onnx`模型出现精度不达标或者是数据维度的问题,可以参考[yolor#32](https://github.com/WongKinYiu/yolor/issues/32)的解决办法 + + ``` + #下载yolor模型文件 + wget https://github.com/WongKinYiu/yolor/releases/download/weights/yolor-d6-paper-570.pt + + # 导出onnx格式文件 + python models/export.py --weights PATH/TO/yolor-xx-xx-xx.pt --img-size 640 + + # 移动onnx文件到demo目录 + cp PATH/TO/yolor.onnx PATH/TO/model_zoo/vision/yolor/ + ``` + + +## 运行demo + +``` +# 下载和解压预测库 +wget https://bj.bcebos.com/paddle2onnx/fastdeploy/fastdeploy-linux-x64-0.0.3.tgz +tar xvf fastdeploy-linux-x64-0.0.3.tgz + +# 编译示例代码 +mkdir build & cd build +cmake .. +make -j + +# 移动onnx文件到demo目录 +cp PATH/TO/yolor.onnx PATH/TO/model_zoo/vision/yolor/cpp/build/ + +# 下载图片 +wget https://raw.githubusercontent.com/WongKinYiu/yolor/paper/inference/images/horses.jpg + +# 执行 +./yolor_demo +``` + +执行完后可视化的结果保存在本地`vis_result.jpg`,同时会将检测框输出在终端,如下所示 +``` +DetectionResult: [xmin, ymin, xmax, ymax, score, label_id] +0.000000,185.201431, 315.673126, 410.071594, 0.959289, 17 +433.802826,211.603455, 595.489319, 346.425537, 0.952615, 17 +230.446854,195.618805, 418.365479, 362.712128, 0.884253, 17 +336.545624,208.555618, 457.704315, 323.543152, 0.788450, 17 +0.896423,183.936996, 154.788727, 304.916412, 0.672804, 17 +``` diff --git a/model_zoo/vision/yolor/cpp/yolor.cc b/model_zoo/vision/yolor/cpp/yolor.cc new file mode 100644 index 000000000..db194583f --- /dev/null +++ b/model_zoo/vision/yolor/cpp/yolor.cc @@ -0,0 +1,40 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fastdeploy/vision.h" + +int main() { + namespace vis = fastdeploy::vision; + auto model = vis::wongkinyiu::YOLOR("yolor.onnx"); + if (!model.Initialized()) { + std::cerr << "Init Failed." << std::endl; + return -1; + } + cv::Mat im = cv::imread("horses.jpg"); + cv::Mat vis_im = im.clone(); + + vis::DetectionResult res; + if (!model.Predict(&im, &res)) { + std::cerr << "Prediction Failed." << std::endl; + return -1; + } + + // 输出预测框结果 + std::cout << res.Str() << std::endl; + + // 可视化预测结果 + vis::Visualize::VisDetection(&vis_im, res); + cv::imwrite("vis_result.jpg", vis_im); + return 0; +} diff --git a/model_zoo/vision/yolor/yolor.py b/model_zoo/vision/yolor/yolor.py new file mode 100644 index 000000000..56d3f9689 --- /dev/null +++ b/model_zoo/vision/yolor/yolor.py @@ -0,0 +1,21 @@ +import fastdeploy as fd +import cv2 + +# 下载模型和测试图片 +test_jpg_url = "https://raw.githubusercontent.com/WongKinYiu/yolor/paper/inference/images/horses.jpg" +fd.download(test_jpg_url, ".", show_progress=True) + +# 加载模型 +model = fd.vision.wongkinyiu.YOLOR("yolor.onnx") + +# 预测图片 +im = cv2.imread("horses.jpg") +result = model.predict(im, conf_threshold=0.25, nms_iou_threshold=0.5) + +# 可视化结果 +fd.vision.visualize.vis_detection(im, result) +cv2.imwrite("vis_result.jpg", im) + +# 输出预测结果 +print(result) +print(model.runtime_option) diff --git a/model_zoo/vision/yolov7/README.md b/model_zoo/vision/yolov7/README.md index 2bb13ce45..8b2f06d76 100644 --- a/model_zoo/vision/yolov7/README.md +++ b/model_zoo/vision/yolov7/README.md @@ -27,10 +27,10 @@ wget https://github.com/WongKinYiu/yolov7/releases/download/v0.1/yolov7.pt # 导出onnx格式文件 - python models/export.py --grid --dynamic --weights PATH/TO/yolo7.pt + python models/export.py --grid --dynamic --weights PATH/TO/yolov7.pt # 移动onnx文件到demo目录 - cp PATH/TO/yolo7.onnx PATH/TO/model_zoo/vision/yolov7/ + cp PATH/TO/yolov7.onnx PATH/TO/model_zoo/vision/yolov7/ ``` ## 安装FastDeploy diff --git a/model_zoo/vision/yolov7/cpp/README.md b/model_zoo/vision/yolov7/cpp/README.md index f216c1aec..655e98678 100644 --- a/model_zoo/vision/yolov7/cpp/README.md +++ b/model_zoo/vision/yolov7/cpp/README.md @@ -13,7 +13,7 @@ wget https://github.com/WongKinYiu/yolov7/releases/download/v0.1/yolov7.pt # 导出onnx格式文件 - python models/export.py --grid --dynamic --weights PATH/TO/yolo7.pt + python models/export.py --grid --dynamic --weights PATH/TO/yolov7.pt ``` @@ -31,7 +31,7 @@ cmake .. make -j # 移动onnx文件到demo目录 -cp PATH/TO/yolo7.onnx PATH/TO/model_zoo/vision/yolov7/cpp/build/ +cp PATH/TO/yolov7.onnx PATH/TO/model_zoo/vision/yolov7/cpp/build/ # 下载图片 wget https://raw.githubusercontent.com/WongKinYiu/yolov7/main/inference/images/horses.jpg