From bd0482f31401e118876a7a686ef2738e16812306 Mon Sep 17 00:00:00 2001
From: ziqi-jin <67993288+ziqi-jin@users.noreply.github.com>
Date: Thu, 4 Aug 2022 15:30:59 +0800
Subject: [PATCH] Add model SCRFD support (#68)

* first commit for yolov7
* pybind for yolov7
* C++/Python README.md and api.md for yolov7; yolov7 release link; copyright
* modified yolov7.cc; python file modified; file paths modified
* delete license in fastdeploy/; repush the conflict part; gitignore
* move some helpers to private; change variables to const and fix documents
* add examples for yolov7
* transfer some functions to private members of the class
* merge from develop (#9, #11, #12, #13, #14):
  fix compile problem in different python version (#26);
  fix some usage problems in linux;
  add PaddleDetection/PPYOLOE model support (#22);
  add convert processor to vision (#27);
  update .gitignore; add checking for cmake include dir;
  fix missing trt_backend option bug when init from trt;
  remove un-needed data layout and add pre-check for dtype;
  changed RGB2BRG to BGR2RGB in ppcls model;
  add model_zoo yolov6/yolox c++/python demos;
  fix CMakeLists.txt typos; add normalize with alpha and beta;
  add version notes for yolov5/yolov6/yolox; add copyright to yolov5.cc;
  revert normalize; fix some bugs in yolox;
  fix examples/CMakeLists.txt to avoid conflicts;
  fix bug while the inference result is empty with YOLOv5 (#29);
  add multi-label function for yolov5;
  fix wrong variable name option.trt_max_shape in fastdeploy_runtime.cc;
  update resnet dynamic shape setting name from images to x in
  runtime_option.md; fix bug when inference result boxes are empty;
  delete detection.py
* first commit for yolor (Yolor #16); documents
* add is_dynamic for YOLO series (#22)
* first commit for scrfd
* bind file and documents for scrfd
* modified for the PR comments; fix second PR comments
* delete the member fmc_, create fmc in the Postprocess function, fix the
  count for nms

Co-authored-by: Jason
Co-authored-by: root
Co-authored-by: DefTruth <31974251+DefTruth@users.noreply.github.com>
Co-authored-by: huangjianhui <852142024@qq.com>
Co-authored-by: Jason <928090362@qq.com>
---
 csrcs/fastdeploy/vision.h                     |   3 +-
 .../vision/deepinsight/deepinsight_pybind.cc  |  47 +++
 csrcs/fastdeploy/vision/deepinsight/scrfd.cc  | 363 ++++++++++++++++++
 csrcs/fastdeploy/vision/deepinsight/scrfd.h   | 122 ++++++
 csrcs/fastdeploy/vision/ppogg/yolov5lite.cc   |  15 +
 csrcs/fastdeploy/vision/ppogg/yolov5lite.h    |   7 +
 csrcs/fastdeploy/vision/vision_pybind.cc      |   2 +
 .../vision/wongkinyiu/scaledyolov4.cc         |  15 +
 .../vision/wongkinyiu/scaledyolov4.h          |   7 +
 csrcs/fastdeploy/vision/wongkinyiu/yolor.cc   |  17 +-
 csrcs/fastdeploy/vision/wongkinyiu/yolor.h    |   7 +
 csrcs/fastdeploy/vision/wongkinyiu/yolov7.cc  |  17 +-
 csrcs/fastdeploy/vision/wongkinyiu/yolov7.h   |   7 +
 examples/vision/deepinsight_scrfd.cc          |  51 +++
 fastdeploy/vision/__init__.py                 |   1 +
 fastdeploy/vision/deepinsight/__init__.py     | 158 ++++++++
 model_zoo/vision/scrfd/README.md              |  92 +++++
 model_zoo/vision/scrfd/api.md                 |  71 ++++
 model_zoo/vision/scrfd/cpp/CMakeLists.txt     |  17 +
 model_zoo/vision/scrfd/cpp/README.md          |  76 ++++
 model_zoo/vision/scrfd/cpp/scrfd.cc           |  44 +++
 model_zoo/vision/scrfd/scrfd.py               |  25 ++
 22 files changed, 1161 insertions(+), 3 deletions(-)
 create mode 100644 csrcs/fastdeploy/vision/deepinsight/deepinsight_pybind.cc
 create mode 100644 csrcs/fastdeploy/vision/deepinsight/scrfd.cc
 create mode 100644 csrcs/fastdeploy/vision/deepinsight/scrfd.h
 create mode 100644 examples/vision/deepinsight_scrfd.cc
 create mode 100644 fastdeploy/vision/deepinsight/__init__.py
 create mode 100644 model_zoo/vision/scrfd/README.md
 create mode 100644 model_zoo/vision/scrfd/api.md
 create mode 100644 model_zoo/vision/scrfd/cpp/CMakeLists.txt
 create mode 100644 model_zoo/vision/scrfd/cpp/README.md
 create mode 100644 model_zoo/vision/scrfd/cpp/scrfd.cc
 create mode 100644 model_zoo/vision/scrfd/scrfd.py

diff --git a/csrcs/fastdeploy/vision.h b/csrcs/fastdeploy/vision.h
index 2c0bdd1fa..7173f3d69 100644
--- a/csrcs/fastdeploy/vision.h
+++ b/csrcs/fastdeploy/vision.h
@@ -17,18 +17,19 @@
 #ifdef ENABLE_VISION
 #include "fastdeploy/vision/biubug6/retinaface.h"
 #include "fastdeploy/vision/deepcam/yolov5face.h"
+#include "fastdeploy/vision/deepinsight/scrfd.h"
 #include "fastdeploy/vision/linzaer/ultraface.h"
 #include "fastdeploy/vision/megvii/yolox.h"
 #include "fastdeploy/vision/meituan/yolov6.h"
 #include "fastdeploy/vision/ppcls/model.h"
 #include "fastdeploy/vision/ppdet/ppyoloe.h"
+#include "fastdeploy/vision/ppogg/yolov5lite.h"
 #include "fastdeploy/vision/ppseg/model.h"
 #include "fastdeploy/vision/rangilyu/nanodet_plus.h"
 #include "fastdeploy/vision/ultralytics/yolov5.h"
 #include "fastdeploy/vision/wongkinyiu/scaledyolov4.h"
 #include "fastdeploy/vision/wongkinyiu/yolor.h"
 #include "fastdeploy/vision/wongkinyiu/yolov7.h"
-#include "fastdeploy/vision/ppogg/yolov5lite.h"
 #endif

 #include "fastdeploy/vision/visualize/visualize.h"
diff --git a/csrcs/fastdeploy/vision/deepinsight/deepinsight_pybind.cc b/csrcs/fastdeploy/vision/deepinsight/deepinsight_pybind.cc
new file mode 100644
index 000000000..459e89b7e
--- /dev/null
+++ b/csrcs/fastdeploy/vision/deepinsight/deepinsight_pybind.cc
@@ -0,0 +1,47 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
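+
+// Python bindings for the deepinsight vision models: SCRFD is exposed as
+// fastdeploy.vision.deepinsight.SCRFD. predict() converts the incoming numpy
+// array to a cv::Mat and returns a FaceDetectionResult, while the
+// def_readwrite entries below surface the preprocessing knobs (size,
+// padding_value, is_mini_pad, ...) as plain Python attributes.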
+ +#include "fastdeploy/pybind/main.h" + +namespace fastdeploy { +void BindDeepinsight(pybind11::module& m) { + auto deepinsight_module = m.def_submodule( + "deepinsight", "https://github.com/deepinsight"); + pybind11::class_( + deepinsight_module, "SCRFD") + .def(pybind11::init()) + .def("predict", + [](vision::deepinsight::SCRFD& self, pybind11::array& data, + float conf_threshold, float nms_iou_threshold) { + auto mat = PyArrayToCvMat(data); + vision::FaceDetectionResult res; + self.Predict(&mat, &res, conf_threshold, nms_iou_threshold); + return res; + }) + .def_readwrite("size", &vision::deepinsight::SCRFD::size) + .def_readwrite("padding_value", + &vision::deepinsight::SCRFD::padding_value) + .def_readwrite("is_mini_pad", &vision::deepinsight::SCRFD::is_mini_pad) + .def_readwrite("is_no_pad", &vision::deepinsight::SCRFD::is_no_pad) + .def_readwrite("is_scale_up", &vision::deepinsight::SCRFD::is_scale_up) + .def_readwrite("stride", &vision::deepinsight::SCRFD::stride) + .def_readwrite("use_kps", &vision::deepinsight::SCRFD::use_kps) + .def_readwrite("max_nms", &vision::deepinsight::SCRFD::max_nms) + .def_readwrite("downsample_strides", + &vision::deepinsight::SCRFD::downsample_strides) + .def_readwrite("num_anchors", &vision::deepinsight::SCRFD::num_anchors) + .def_readwrite("landmarks_per_face", + &vision::deepinsight::SCRFD::landmarks_per_face); +} +} // namespace fastdeploy diff --git a/csrcs/fastdeploy/vision/deepinsight/scrfd.cc b/csrcs/fastdeploy/vision/deepinsight/scrfd.cc new file mode 100644 index 000000000..d86331fbe --- /dev/null +++ b/csrcs/fastdeploy/vision/deepinsight/scrfd.cc @@ -0,0 +1,363 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "fastdeploy/vision/deepinsight/scrfd.h" +#include "fastdeploy/utils/perf.h" +#include "fastdeploy/vision/utils/utils.h" + +namespace fastdeploy { + +namespace vision { + +namespace deepinsight { + +void SCRFD::LetterBox(Mat* mat, const std::vector& size, + const std::vector& color, bool _auto, + bool scale_fill, bool scale_up, int stride) { + float scale = + std::min(size[1] * 1.0 / mat->Height(), size[0] * 1.0 / mat->Width()); + if (!scale_up) { + scale = std::min(scale, 1.0f); + } + + int resize_h = int(round(mat->Height() * scale)); + int resize_w = int(round(mat->Width() * scale)); + + int pad_w = size[0] - resize_w; + int pad_h = size[1] - resize_h; + if (_auto) { + pad_h = pad_h % stride; + pad_w = pad_w % stride; + } else if (scale_fill) { + pad_h = 0; + pad_w = 0; + resize_h = size[1]; + resize_w = size[0]; + } + if (resize_h != mat->Height() || resize_w != mat->Width()) { + Resize::Run(mat, resize_w, resize_h); + } + if (pad_h > 0 || pad_w > 0) { + float half_h = pad_h * 1.0 / 2; + int top = int(round(half_h - 0.1)); + int bottom = int(round(half_h + 0.1)); + float half_w = pad_w * 1.0 / 2; + int left = int(round(half_w - 0.1)); + int right = int(round(half_w + 0.1)); + Pad::Run(mat, top, bottom, left, right, color); + } +} + +SCRFD::SCRFD(const std::string& model_file, const std::string& params_file, + const RuntimeOption& custom_option, const Frontend& model_format) { + if (model_format == Frontend::ONNX) { + valid_cpu_backends = {Backend::ORT}; // 指定可用的CPU后端 + valid_gpu_backends = {Backend::ORT, Backend::TRT}; // 指定可用的GPU后端 + } else { + valid_cpu_backends = {Backend::PDINFER, Backend::ORT}; + valid_gpu_backends = {Backend::PDINFER, Backend::ORT, Backend::TRT}; + } + runtime_option = custom_option; + runtime_option.model_format = model_format; + runtime_option.model_file = model_file; + runtime_option.params_file = params_file; + initialized = Initialize(); +} + +bool SCRFD::Initialize() { + // parameters for preprocess + use_kps = true; + size = {640, 640}; + padding_value = {0.0, 0.0, 0.0}; + is_mini_pad = false; + is_no_pad = false; + is_scale_up = false; + stride = 32; + downsample_strides = {8, 16, 32}; + num_anchors = 2; + landmarks_per_face = 5; + center_points_is_update_ = false; + max_nms = 30000; + // num_outputs = use_kps ? 9 : 6; + if (!InitRuntime()) { + FDERROR << "Failed to initialize fastdeploy backend." << std::endl; + return false; + } + // Check if the input shape is dynamic after Runtime already initialized, + // Note that, We need to force is_mini_pad 'false' to keep static + // shape after padding (LetterBox) when the is_dynamic_shape is 'false'. + is_dynamic_input_ = false; + auto shape = InputInfoOfRuntime(0).shape; + for (int i = 0; i < shape.size(); ++i) { + // if height or width is dynamic + if (i >= 2 && shape[i] <= 0) { + is_dynamic_input_ = true; + break; + } + } + if (!is_dynamic_input_) { + is_mini_pad = false; + } + + return true; +} + +bool SCRFD::Preprocess(Mat* mat, FDTensor* output, + std::map>* im_info) { + float ratio = std::min(size[1] * 1.0f / static_cast(mat->Height()), + size[0] * 1.0f / static_cast(mat->Width())); + if (ratio != 1.0) { + int interp = cv::INTER_AREA; + if (ratio > 1.0) { + interp = cv::INTER_LINEAR; + } + int resize_h = int(mat->Height() * ratio); + int resize_w = int(mat->Width() * ratio); + Resize::Run(mat, resize_w, resize_h, -1, -1, interp); + } + // scrfd's preprocess steps + // 1. letterbox + // 2. BGR->RGB + // 3. 
+  std::vector<float> alpha = {1.f / 128.f, 1.f / 128.f, 1.f / 128.f};
+  std::vector<float> beta = {-127.5f / 128.f, -127.5f / 128.f,
+                             -127.5f / 128.f};
+  Convert::Run(mat, alpha, beta);
+  // Record output shape of preprocessed image
+  (*im_info)["output_shape"] = {static_cast<float>(mat->Height()),
+                                static_cast<float>(mat->Width())};
+  HWC2CHW::Run(mat);
+  Cast::Run(mat, "float");
+  mat->ShareWithTensor(output);
+  output->shape.insert(output->shape.begin(), 1);  // reshape to n, c, h, w
+  return true;
+}
+
+void SCRFD::GeneratePoints() {
+  if (center_points_is_update_ && !is_dynamic_input_) {
+    return;
+  }
+  // 8, 16, 32
+  for (auto local_stride : downsample_strides) {
+    unsigned int num_grid_w = size[0] / local_stride;
+    unsigned int num_grid_h = size[1] / local_stride;
+    // y
+    for (unsigned int i = 0; i < num_grid_h; ++i) {
+      // x
+      for (unsigned int j = 0; j < num_grid_w; ++j) {
+        // num_anchors, col major
+        for (unsigned int k = 0; k < num_anchors; ++k) {
+          SCRFDPoint point;
+          point.cx = static_cast<float>(j);
+          point.cy = static_cast<float>(i);
+          center_points_[local_stride].push_back(point);
+        }
+      }
+    }
+  }
+
+  center_points_is_update_ = true;
+}
+
+bool SCRFD::Postprocess(
+    std::vector<FDTensor>& infer_result, FaceDetectionResult* result,
+    const std::map<std::string, std::array<float, 2>>& im_info,
+    float conf_threshold, float nms_iou_threshold) {
+  // number of downsample_strides
+  int fmc = downsample_strides.size();
+  // scrfd exports have 6, 9, 10 or 15 output tensors
+  FDASSERT((infer_result.size() == 9 || infer_result.size() == 6 ||
+            infer_result.size() == 10 || infer_result.size() == 15),
+           "The default number of output tensor must be 6, 9, 10, or 15 "
+           "according to scrfd.");
+  FDASSERT((fmc == 3 || fmc == 5), "The fmc must be 3 or 5");
+  FDASSERT((infer_result.at(0).shape[0] == 1), "Only support batch = 1 now.");
+  for (int i = 0; i < fmc; ++i) {
+    if (infer_result.at(i).dtype != FDDataType::FP32) {
+      FDERROR << "Only support post process with float32 data." << std::endl;
+      return false;
+    }
+  }
+  int total_num_boxes = 0;
+  // compute the reserve space.
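+  // Assumed output layout (matching the official ONNX export): tensor f holds
+  // the scores for stride downsample_strides[f], tensor f + fmc the box
+  // offsets, and tensor f + 2 * fmc the landmark offsets, each shaped
+  // (1, num_points, ...). total_num_boxes therefore sums shape[1] over the
+  // score tensors.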
+  for (int f = 0; f < fmc; ++f) {
+    total_num_boxes += infer_result.at(f).shape[1];
+  }
+  GeneratePoints();
+  result->Clear();
+  // scale the boxes to the origin image shape
+  auto iter_out = im_info.find("output_shape");
+  auto iter_ipt = im_info.find("input_shape");
+  FDASSERT(iter_out != im_info.end() && iter_ipt != im_info.end(),
+           "Cannot find input_shape or output_shape from im_info.");
+  float out_h = iter_out->second[0];
+  float out_w = iter_out->second[1];
+  float ipt_h = iter_ipt->second[0];
+  float ipt_w = iter_ipt->second[1];
+  float scale = std::min(out_h / ipt_h, out_w / ipt_w);
+  float pad_h = (out_h - ipt_h * scale) / 2.0f;
+  float pad_w = (out_w - ipt_w * scale) / 2.0f;
+  if (is_mini_pad) {
+    // corresponds to the _auto=true branch in LetterBox
+    pad_h = static_cast<float>(static_cast<int>(pad_h) % stride);
+    pad_w = static_cast<float>(static_cast<int>(pad_w) % stride);
+  }
+  // landmarks_per_face must be set before Reserve
+  result->landmarks_per_face = landmarks_per_face;
+  result->Reserve(total_num_boxes);
+  unsigned int count = 0;
+  // loop over each stride
+  for (int f = 0; f < fmc; ++f) {
+    float* score_ptr = static_cast<float*>(infer_result.at(f).Data());
+    float* bbox_ptr = static_cast<float*>(infer_result.at(f + fmc).Data());
+    const unsigned int num_points = infer_result.at(f).shape[1];
+    int current_stride = downsample_strides[f];
+    auto& stride_points = center_points_[current_stride];
+    // loop over each anchor
+    for (unsigned int i = 0; i < num_points; ++i) {
+      const float cls_conf = score_ptr[i];
+      if (cls_conf < conf_threshold) continue;  // filter
+      auto& point = stride_points.at(i);
+      const float cx = point.cx;  // cx
+      const float cy = point.cy;  // cy
+      // bbox
+      const float* offsets = bbox_ptr + i * 4;
+      float l = offsets[0];  // left
+      float t = offsets[1];  // top
+      float r = offsets[2];  // right
+      float b = offsets[3];  // bottom
+
+      float x1 = ((cx - l) * static_cast<float>(current_stride) -
+                  static_cast<float>(pad_w)) / scale;  // cx - l -> x1
+      float y1 = ((cy - t) * static_cast<float>(current_stride) -
+                  static_cast<float>(pad_h)) / scale;  // cy - t -> y1
+      float x2 = ((cx + r) * static_cast<float>(current_stride) -
+                  static_cast<float>(pad_w)) / scale;  // cx + r -> x2
+      float y2 = ((cy + b) * static_cast<float>(current_stride) -
+                  static_cast<float>(pad_h)) / scale;  // cy + b -> y2
+      result->boxes.emplace_back(std::array<float, 4>{x1, y1, x2, y2});
+      result->scores.push_back(cls_conf);
+      if (use_kps) {
+        float* landmarks_ptr =
+            static_cast<float*>(infer_result.at(f + 2 * fmc).Data());
+        // landmarks
+        const float* kps_offsets =
+            landmarks_ptr + i * (landmarks_per_face * 2);
+        for (unsigned int j = 0; j < landmarks_per_face * 2; j += 2) {
+          float kps_l = kps_offsets[j];
+          float kps_t = kps_offsets[j + 1];
+          float kps_x = ((cx + kps_l) * static_cast<float>(current_stride) -
+                         static_cast<float>(pad_w)) / scale;  // cx + l -> x
+          float kps_y = ((cy + kps_t) * static_cast<float>(current_stride) -
+                         static_cast<float>(pad_h)) / scale;  // cy + t -> y
+          result->landmarks.emplace_back(std::array<float, 2>{kps_x, kps_y});
+        }
+      }
+      count += 1;  // limit the boxes fed to nms
+      if (count > max_nms) {
+        break;
+      }
+    }
+  }
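+
+  // Recap of the decoding above: SCRFD regresses distances (l, t, r, b) from
+  // each anchor center in grid units, so a corner is e.g.
+  // x1 = (cx - l) * stride; subtracting the letterbox padding and dividing by
+  // the resize scale then maps it back to original-image pixels.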
+
+  // fetch original image shape
+  FDASSERT((iter_ipt != im_info.end()),
+           "Cannot find input_shape from im_info.");
+
+  if (result->boxes.size() == 0) {
+    return true;
+  }
+
+  utils::NMS(result, nms_iou_threshold);
+
+  // scale and clip box
+  for (size_t i = 0; i < result->boxes.size(); ++i) {
+    result->boxes[i][0] = std::max(result->boxes[i][0], 0.0f);
+    result->boxes[i][1] = std::max(result->boxes[i][1], 0.0f);
+    result->boxes[i][2] = std::max(result->boxes[i][2], 0.0f);
+    result->boxes[i][3] = std::max(result->boxes[i][3], 0.0f);
+    result->boxes[i][0] = std::min(result->boxes[i][0], ipt_w - 1.0f);
+    result->boxes[i][1] = std::min(result->boxes[i][1], ipt_h - 1.0f);
+    result->boxes[i][2] = std::min(result->boxes[i][2], ipt_w - 1.0f);
+    result->boxes[i][3] = std::min(result->boxes[i][3], ipt_h - 1.0f);
+  }
+  // scale and clip landmarks
+  for (size_t i = 0; i < result->landmarks.size(); ++i) {
+    result->landmarks[i][0] = std::max(result->landmarks[i][0], 0.0f);
+    result->landmarks[i][1] = std::max(result->landmarks[i][1], 0.0f);
+    result->landmarks[i][0] = std::min(result->landmarks[i][0], ipt_w - 1.0f);
+    result->landmarks[i][1] = std::min(result->landmarks[i][1], ipt_h - 1.0f);
+  }
+  return true;
+}
+
+bool SCRFD::Predict(cv::Mat* im, FaceDetectionResult* result,
+                    float conf_threshold, float nms_iou_threshold) {
+#ifdef FASTDEPLOY_DEBUG
+  TIMERECORD_START(0)
+#endif
+  Mat mat(*im);
+  std::vector<FDTensor> input_tensors(1);
+
+  std::map<std::string, std::array<float, 2>> im_info;
+
+  // Record the shape of image and the shape of preprocessed image
+  im_info["input_shape"] = {static_cast<float>(mat.Height()),
+                            static_cast<float>(mat.Width())};
+  im_info["output_shape"] = {static_cast<float>(mat.Height()),
+                             static_cast<float>(mat.Width())};
+
+  if (!Preprocess(&mat, &input_tensors[0], &im_info)) {
+    FDERROR << "Failed to preprocess input image." << std::endl;
+    return false;
+  }
+
+#ifdef FASTDEPLOY_DEBUG
+  TIMERECORD_END(0, "Preprocess")
+  TIMERECORD_START(1)
+#endif
+
+  input_tensors[0].name = InputInfoOfRuntime(0).name;
+  std::vector<FDTensor> output_tensors;
+  if (!Infer(input_tensors, &output_tensors)) {
+    FDERROR << "Failed to inference." << std::endl;
+    return false;
+  }
+#ifdef FASTDEPLOY_DEBUG
+  TIMERECORD_END(1, "Inference")
+  TIMERECORD_START(2)
+#endif
+
+  if (!Postprocess(output_tensors, result, im_info, conf_threshold,
+                   nms_iou_threshold)) {
+    FDERROR << "Failed to post process." << std::endl;
+    return false;
+  }
+
+#ifdef FASTDEPLOY_DEBUG
+  TIMERECORD_END(2, "Postprocess")
+#endif
+  return true;
+}
+
+}  // namespace deepinsight
+}  // namespace vision
+}  // namespace fastdeploy
\ No newline at end of file
diff --git a/csrcs/fastdeploy/vision/deepinsight/scrfd.h b/csrcs/fastdeploy/vision/deepinsight/scrfd.h
new file mode 100644
index 000000000..a84eab5f1
--- /dev/null
+++ b/csrcs/fastdeploy/vision/deepinsight/scrfd.h
@@ -0,0 +1,122 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <unordered_map>
+#include "fastdeploy/fastdeploy_model.h"
+#include "fastdeploy/vision/common/processors/transform.h"
+#include "fastdeploy/vision/common/result.h"
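+
+// SCRFD ("Sample and Computation Redistribution for Efficient Face
+// Detection", from deepinsight/insightface) detects faces and, when the
+// exported model provides key points (use_kps), five landmarks per face.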
+
+namespace fastdeploy {
+
+namespace vision {
+
+namespace deepinsight {
+
+class FASTDEPLOY_DECL SCRFD : public FastDeployModel {
+ public:
+  // When model_format is ONNX, params_file is not needed.
+  // When model_format is Paddle, both model_file and params_file are needed.
+  SCRFD(const std::string& model_file, const std::string& params_file = "",
+        const RuntimeOption& custom_option = RuntimeOption(),
+        const Frontend& model_format = Frontend::ONNX);
+
+  // the name of this model
+  std::string ModelName() const { return "deepinsight/scrfd"; }
+
+  // The prediction interface called by users.
+  // im is the input data; for CV models it is a cv::Mat.
+  // result is the output struct filled by the model.
+  // conf_threshold and nms_iou_threshold are postprocess parameters.
+  virtual bool Predict(cv::Mat* im, FaceDetectionResult* result,
+                       float conf_threshold = 0.25f,
+                       float nms_iou_threshold = 0.4f);
+
+  // The members below are pre/postprocess parameters. After creating the
+  // model, users may modify them according to the model's requirements and
+  // their own needs.
+  // tuple of (width, height), default (640, 640)
+  std::vector<int> size;
+  // padding value, size should be the same as channels
+  std::vector<float> padding_value;
+  // only pad to the minimum rectangle whose height and width are multiples
+  // of stride
+  bool is_mini_pad;
+  // when is_mini_pad = false and is_no_pad = true, resize the image to the
+  // set size
+  bool is_no_pad;
+  // if is_scale_up is false, the input image can only be zoomed out; the
+  // maximum resize scale cannot exceed 1.0
+  bool is_scale_up;
+  // padding stride, for is_mini_pad
+  int stride;
+  // downsample strides (namely, steps) used by SCRFD to generate anchor
+  // centers, default (8, 16, 32)
+  std::vector<int> downsample_strides;
+  // landmarks_per_face, default 5 in SCRFD
+  int landmarks_per_face;
+  // whether the onnx file outputs key-point features or not
+  bool use_kps;
+  // the upper bound on the number of boxes processed by nms
+  int max_nms;
+  // number of anchors per grid cell for each stride
+  unsigned int num_anchors;
+
+ private:
+  // Initialize the backend and everything else inference depends on.
+  bool Initialize();
+
+  // Preprocess the input image.
+  // Mat is the data structure defined by FastDeploy.
+  // FDTensor receives the preprocessed data passed to the backend.
+  // im_info stores data recorded during preprocessing, needed by postprocess.
+  bool Preprocess(Mat* mat, FDTensor* output,
+                  std::map<std::string, std::array<float, 2>>* im_info);
+
+  // Postprocess the backend inference result for users.
+  // infer_result is the output tensors from the backend.
+  // result is the prediction result.
+  // im_info is the information recorded in preprocess, used to restore boxes.
+  // conf_threshold filters boxes by confidence.
+  // nms_iou_threshold is the iou threshold used by NMS.
+  bool Postprocess(std::vector<FDTensor>& infer_result,
+                   FaceDetectionResult* result,
+                   const std::map<std::string, std::array<float, 2>>& im_info,
+                   float conf_threshold, float nms_iou_threshold);
+
+  void GeneratePoints();
+
+  // Apply LetterBox to the image.
+  // mat is the original image.
+  // size is the input size of the model.
+  void LetterBox(Mat* mat, const std::vector<int>& size,
+                 const std::vector<float>& color, bool _auto,
+                 bool scale_fill = false, bool scale_up = true,
+                 int stride = 32);
+
+  bool is_dynamic_input_;
+
+  bool center_points_is_update_;
+
+  typedef struct {
+    float cx;
+    float cy;
+  } SCRFDPoint;
+
+  std::unordered_map<int, std::vector<SCRFDPoint>> center_points_;
+};
+}  // namespace deepinsight
+}  // namespace vision
+}  // namespace fastdeploy
diff --git a/csrcs/fastdeploy/vision/ppogg/yolov5lite.cc b/csrcs/fastdeploy/vision/ppogg/yolov5lite.cc
index 320867f58..a84ead937 100644
--- a/csrcs/fastdeploy/vision/ppogg/yolov5lite.cc
+++ b/csrcs/fastdeploy/vision/ppogg/yolov5lite.cc
@@ -118,6 +118,21 @@ bool YOLOv5Lite::Initialize() {
     FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
     return false;
   }
+  // Check whether the input shape is dynamic after the Runtime is
+  // initialized. Note that we need to force is_mini_pad to 'false' to keep a
+  // static shape after padding (LetterBox) when is_dynamic_input_ is 'false'.
+  is_dynamic_input_ = false;
+  auto shape = InputInfoOfRuntime(0).shape;
+  for (int i = 0; i < shape.size(); ++i) {
+    // if height or width is dynamic
+    if (i >= 2 && shape[i] <= 0) {
+      is_dynamic_input_ = true;
+      break;
+    }
+  }
+  if (!is_dynamic_input_) {
+    is_mini_pad = false;
+  }
   return true;
 }
diff --git a/csrcs/fastdeploy/vision/ppogg/yolov5lite.h b/csrcs/fastdeploy/vision/ppogg/yolov5lite.h
index 3eb556cfa..77209e7ae 100644
--- a/csrcs/fastdeploy/vision/ppogg/yolov5lite.h
+++ b/csrcs/fastdeploy/vision/ppogg/yolov5lite.h
@@ -126,6 +126,13 @@ class FASTDEPLOY_DECL YOLOv5Lite : public FastDeployModel {
   void GenerateAnchors(const std::vector<int>& size,
                        const std::vector<int>& downsample_strides,
                        std::vector<Anchor>* anchors, const int num_anchors = 3);
+
+  // Whether to run inference with dynamic input shape
+  // (e.g. an ONNX model exported with dynamic shape).
+  // When is_dynamic_input_ is 'false', is_mini_pad is forced
+  // to 'false'. This value is checked automatically by
+  // FastDeploy after the internal Runtime is initialized.
+  bool is_dynamic_input_;
 };
 }  // namespace ppogg
 }  // namespace vision
diff --git a/csrcs/fastdeploy/vision/vision_pybind.cc b/csrcs/fastdeploy/vision/vision_pybind.cc
index 79aa87635..e0aab6fdc 100644
--- a/csrcs/fastdeploy/vision/vision_pybind.cc
+++ b/csrcs/fastdeploy/vision/vision_pybind.cc
@@ -28,6 +28,7 @@ void BindRangiLyu(pybind11::module& m);
 void BindLinzaer(pybind11::module& m);
 void BindBiubug6(pybind11::module& m);
 void BindPpogg(pybind11::module& m);
+void BindDeepinsight(pybind11::module& m);
 #ifdef ENABLE_VISION_VISUALIZE
 void BindVisualize(pybind11::module& m);
 #endif
@@ -75,6 +76,7 @@ void BindVision(pybind11::module& m) {
   BindLinzaer(m);
   BindBiubug6(m);
   BindPpogg(m);
+  BindDeepinsight(m);
 #ifdef ENABLE_VISION_VISUALIZE
   BindVisualize(m);
 #endif
diff --git a/csrcs/fastdeploy/vision/wongkinyiu/scaledyolov4.cc b/csrcs/fastdeploy/vision/wongkinyiu/scaledyolov4.cc
index 7321fc01b..a562c9b27 100644
--- a/csrcs/fastdeploy/vision/wongkinyiu/scaledyolov4.cc
+++ b/csrcs/fastdeploy/vision/wongkinyiu/scaledyolov4.cc
@@ -89,6 +89,21 @@ bool ScaledYOLOv4::Initialize() {
     FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
     return false;
   }
+  // Check whether the input shape is dynamic after the Runtime is
+  // initialized. Note that we need to force is_mini_pad to 'false' to keep a
+  // static shape after padding (LetterBox) when is_dynamic_input_ is 'false'.
+  is_dynamic_input_ = false;
+  auto shape = InputInfoOfRuntime(0).shape;
+  for (int i = 0; i < shape.size(); ++i) {
+    // if height or width is dynamic
+    if (i >= 2 && shape[i] <= 0) {
+      is_dynamic_input_ = true;
+      break;
+    }
+  }
+  if (!is_dynamic_input_) {
+    is_mini_pad = false;
+  }
   return true;
 }
diff --git a/csrcs/fastdeploy/vision/wongkinyiu/scaledyolov4.h b/csrcs/fastdeploy/vision/wongkinyiu/scaledyolov4.h
index 39066a29e..c85b58d4c 100644
--- a/csrcs/fastdeploy/vision/wongkinyiu/scaledyolov4.h
+++ b/csrcs/fastdeploy/vision/wongkinyiu/scaledyolov4.h
@@ -90,6 +90,13 @@ class FASTDEPLOY_DECL ScaledYOLOv4 : public FastDeployModel {
                  const std::vector<float>& color, bool _auto,
                  bool scale_fill = false, bool scale_up = true,
                  int stride = 32);
+
+  // Whether to run inference with dynamic input shape
+  // (e.g. an ONNX model exported with dynamic shape).
+  // When is_dynamic_input_ is 'false', is_mini_pad is forced
+  // to 'false'. This value is checked automatically by
+  // FastDeploy after the internal Runtime is initialized.
+  bool is_dynamic_input_;
 };
 }  // namespace wongkinyiu
 }  // namespace vision
diff --git a/csrcs/fastdeploy/vision/wongkinyiu/yolor.cc b/csrcs/fastdeploy/vision/wongkinyiu/yolor.cc
index 070ea72e6..7de994f2a 100644
--- a/csrcs/fastdeploy/vision/wongkinyiu/yolor.cc
+++ b/csrcs/fastdeploy/vision/wongkinyiu/yolor.cc
@@ -87,6 +87,21 @@ bool YOLOR::Initialize() {
     FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
     return false;
   }
+  // Check whether the input shape is dynamic after the Runtime is
+  // initialized. Note that we need to force is_mini_pad to 'false' to keep a
+  // static shape after padding (LetterBox) when is_dynamic_input_ is 'false'.
+  is_dynamic_input_ = false;
+  auto shape = InputInfoOfRuntime(0).shape;
+  for (int i = 0; i < shape.size(); ++i) {
+    // if height or width is dynamic
+    if (i >= 2 && shape[i] <= 0) {
+      is_dynamic_input_ = true;
+      break;
+    }
+  }
+  if (!is_dynamic_input_) {
+    is_mini_pad = false;
+  }
   return true;
 }
@@ -176,7 +191,7 @@ bool YOLOR::Postprocess(
   float pad_h = (out_h - ipt_h * scale) / 2.0f;
   float pad_w = (out_w - ipt_w * scale) / 2.0f;
   if (is_mini_pad) {
-    // 和 LetterBox中_auto=true的处理逻辑对应
+    // corresponds to the _auto=true branch in LetterBox
     pad_h = static_cast<float>(static_cast<int>(pad_h) % stride);
     pad_w = static_cast<float>(static_cast<int>(pad_w) % stride);
   }
diff --git a/csrcs/fastdeploy/vision/wongkinyiu/yolor.h b/csrcs/fastdeploy/vision/wongkinyiu/yolor.h
index 7597f42d3..05bbd6421 100644
--- a/csrcs/fastdeploy/vision/wongkinyiu/yolor.h
+++ b/csrcs/fastdeploy/vision/wongkinyiu/yolor.h
@@ -89,6 +89,13 @@ class FASTDEPLOY_DECL YOLOR : public FastDeployModel {
                  const std::vector<float>& color, bool _auto,
                  bool scale_fill = false, bool scale_up = true,
                  int stride = 32);
+
+  // Whether to run inference with dynamic input shape
+  // (e.g. an ONNX model exported with dynamic shape).
+  // When is_dynamic_input_ is 'false', is_mini_pad is forced
+  // to 'false'. This value is checked automatically by
+  // FastDeploy after the internal Runtime is initialized.
+  bool is_dynamic_input_;
 };
 }  // namespace wongkinyiu
 }  // namespace vision
diff --git a/csrcs/fastdeploy/vision/wongkinyiu/yolov7.cc b/csrcs/fastdeploy/vision/wongkinyiu/yolov7.cc
index 457f8800c..6f603c87f 100644
--- a/csrcs/fastdeploy/vision/wongkinyiu/yolov7.cc
+++ b/csrcs/fastdeploy/vision/wongkinyiu/yolov7.cc
@@ -88,6 +88,21 @@ bool YOLOv7::Initialize() {
     FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
     return false;
   }
+  // Check whether the input shape is dynamic after the Runtime is
+  // initialized. Note that we need to force is_mini_pad to 'false' to keep a
+  // static shape after padding (LetterBox) when is_dynamic_input_ is 'false'.
+  is_dynamic_input_ = false;
+  auto shape = InputInfoOfRuntime(0).shape;
+  for (int i = 0; i < shape.size(); ++i) {
+    // if height or width is dynamic
+    if (i >= 2 && shape[i] <= 0) {
+      is_dynamic_input_ = true;
+      break;
+    }
+  }
+  if (!is_dynamic_input_) {
+    is_mini_pad = false;
+  }
   return true;
 }
@@ -177,7 +192,7 @@ bool YOLOv7::Postprocess(
   float pad_h = (out_h - ipt_h * scale) / 2.0f;
   float pad_w = (out_w - ipt_w * scale) / 2.0f;
   if (is_mini_pad) {
-    // 和 LetterBox中_auto=true的处理逻辑对应
+    // corresponds to the _auto=true branch in LetterBox
     pad_h = static_cast<float>(static_cast<int>(pad_h) % stride);
     pad_w = static_cast<float>(static_cast<int>(pad_w) % stride);
   }
diff --git a/csrcs/fastdeploy/vision/wongkinyiu/yolov7.h b/csrcs/fastdeploy/vision/wongkinyiu/yolov7.h
index 64e18ad47..595530b9c 100644
--- a/csrcs/fastdeploy/vision/wongkinyiu/yolov7.h
+++ b/csrcs/fastdeploy/vision/wongkinyiu/yolov7.h
@@ -89,6 +89,13 @@ class FASTDEPLOY_DECL YOLOv7 : public FastDeployModel {
                  const std::vector<float>& color, bool _auto,
                  bool scale_fill = false, bool scale_up = true,
                  int stride = 32);
+
+  // Whether to run inference with dynamic input shape
+  // (e.g. an ONNX model exported with dynamic shape).
+  // When is_dynamic_input_ is 'false', is_mini_pad is forced
+  // to 'false'. This value is checked automatically by
+  // FastDeploy after the internal Runtime is initialized.
+  bool is_dynamic_input_;
 };
 }  // namespace wongkinyiu
 }  // namespace vision
diff --git a/examples/vision/deepinsight_scrfd.cc b/examples/vision/deepinsight_scrfd.cc
new file mode 100644
index 000000000..0ff68db93
--- /dev/null
+++ b/examples/vision/deepinsight_scrfd.cc
@@ -0,0 +1,51 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fastdeploy/vision.h"
+int main() {
+  namespace vis = fastdeploy::vision;
+
+  std::string model_file = "../resources/models/SCRFD.onnx";
+  std::string img_path = "../resources/images/test_face_det.jpg";
+  std::string vis_path =
+      "../resources/outputs/deepinsight_scrfd_vis_result.jpg";
+
+  auto model = vis::deepinsight::SCRFD(model_file);
+  model.size = {640, 640};  // (width, height)
+  if (!model.Initialized()) {
+    std::cerr << "Init Failed! Model: " << model_file << std::endl;
+    return -1;
+  } else {
+    std::cout << "Init Done! Model: " << model_file << std::endl;
+  }
+  model.EnableDebug();
+
+  cv::Mat im = cv::imread(img_path);
+  cv::Mat vis_im = im.clone();
+  vis::FaceDetectionResult res;
+  if (!model.Predict(&im, &res, 0.3f, 0.3f)) {
+    std::cerr << "Prediction Failed." << std::endl;
+    return -1;
+  } else {
+    std::cout << "Prediction Done!" << std::endl;
+  }
+
+  // print the detection results
+  std::cout << res.Str() << std::endl;
+
+  // visualize the prediction
+  vis::Visualize::VisFaceDetection(&vis_im, res, 2, 0.3f);
+  cv::imwrite(vis_path, vis_im);
+  std::cout << "Detect Done! Saved: " << vis_path << std::endl;
Saved: " << vis_path << std::endl; + return 0; +} diff --git a/fastdeploy/vision/__init__.py b/fastdeploy/vision/__init__.py index 067659570..223b7b1a8 100644 --- a/fastdeploy/vision/__init__.py +++ b/fastdeploy/vision/__init__.py @@ -27,3 +27,4 @@ from . import rangilyu from . import linzaer from . import biubug6 from . import ppogg +from . import deepinsight diff --git a/fastdeploy/vision/deepinsight/__init__.py b/fastdeploy/vision/deepinsight/__init__.py new file mode 100644 index 000000000..106832455 --- /dev/null +++ b/fastdeploy/vision/deepinsight/__init__.py @@ -0,0 +1,158 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +import logging +from ... import FastDeployModel, Frontend +from ... import fastdeploy_main as C + + +class SCRFD(FastDeployModel): + def __init__(self, + model_file, + params_file="", + runtime_option=None, + model_format=Frontend.ONNX): + # 调用基函数进行backend_option的初始化 + # 初始化后的option保存在self._runtime_option + super(SCRFD, self).__init__(runtime_option) + + self._model = C.vision.deepinsight.SCRFD( + model_file, params_file, self._runtime_option, model_format) + # 通过self.initialized判断整个模型的初始化是否成功 + assert self.initialized, "SCRFD initialize failed." + + def predict(self, input_image, conf_threshold=0.7, nms_iou_threshold=0.3): + return self._model.predict(input_image, conf_threshold, + nms_iou_threshold) + + # 一些跟SCRFD模型有关的属性封装 + # 多数是预处理相关,可通过修改如model.size = [640, 640]改变预处理时resize的大小(前提是模型支持) + @property + def size(self): + return self._model.size + + @property + def padding_value(self): + return self._model.padding_value + + @property + def is_no_pad(self): + return self._model.is_no_pad + + @property + def is_mini_pad(self): + return self._model.is_mini_pad + + @property + def is_scale_up(self): + return self._model.is_scale_up + + @property + def stride(self): + return self._model.stride + + @property + def downsample_strides(self): + return self._model.downsample_strides + + @property + def landmarks_per_face(self): + return self._model.landmarks_per_face + + @property + def use_kps(self): + return self._model.use_kps + + @property + def max_nms(self): + return self._model.max_nms + + @property + def num_anchors(self): + return self._model.num_anchors + + @size.setter + def size(self, wh): + assert isinstance(wh, [list, tuple]),\ + "The value to set `size` must be type of tuple or list." + assert len(wh) == 2,\ + "The value to set `size` must contatins 2 elements means [width, height], but now it contains {} elements.".format( + len(wh)) + self._model.size = wh + + @padding_value.setter + def padding_value(self, value): + assert isinstance( + value, + list), "The value to set `padding_value` must be type of list." + self._model.padding_value = value + + @is_no_pad.setter + def is_no_pad(self, value): + assert isinstance( + value, bool), "The value to set `is_no_pad` must be type of bool." 
+
+    @is_mini_pad.setter
+    def is_mini_pad(self, value):
+        assert isinstance(
+            value,
+            bool), "The value to set `is_mini_pad` must be type of bool."
+        self._model.is_mini_pad = value
+
+    @is_scale_up.setter
+    def is_scale_up(self, value):
+        assert isinstance(
+            value,
+            bool), "The value to set `is_scale_up` must be type of bool."
+        self._model.is_scale_up = value
+
+    @stride.setter
+    def stride(self, value):
+        assert isinstance(
+            value, int), "The value to set `stride` must be type of int."
+        self._model.stride = value
+
+    @downsample_strides.setter
+    def downsample_strides(self, value):
+        assert isinstance(
+            value,
+            list), "The value to set `downsample_strides` must be type of list."
+        self._model.downsample_strides = value
+
+    @landmarks_per_face.setter
+    def landmarks_per_face(self, value):
+        assert isinstance(
+            value,
+            int), "The value to set `landmarks_per_face` must be type of int."
+        self._model.landmarks_per_face = value
+
+    @use_kps.setter
+    def use_kps(self, value):
+        assert isinstance(
+            value, bool), "The value to set `use_kps` must be type of bool."
+        self._model.use_kps = value
+
+    @max_nms.setter
+    def max_nms(self, value):
+        assert isinstance(
+            value, int), "The value to set `max_nms` must be type of int."
+        self._model.max_nms = value
+
+    @num_anchors.setter
+    def num_anchors(self, value):
+        assert isinstance(
+            value, int), "The value to set `num_anchors` must be type of int."
+        self._model.num_anchors = value
diff --git a/model_zoo/vision/scrfd/README.md b/model_zoo/vision/scrfd/README.md
new file mode 100644
index 000000000..4424f59a3
--- /dev/null
+++ b/model_zoo/vision/scrfd/README.md
@@ -0,0 +1,92 @@
+# Compile and Run the SCRFD Example
+
+The currently supported model version is [SCRFD CID:17cdeab](https://github.com/deepinsight/insightface/tree/17cdeab12a35efcebc2660453a8cbeae96e20950).
+
+This document explains how to quickly deploy and run inference with [SCRFD](https://github.com/deepinsight/insightface/tree/master/detection/scrfd). The directory is laid out as follows:
+
+```
+.
+├── cpp
+│   ├── CMakeLists.txt
+│   ├── README.md
+│   └── scrfd.cc
+├── README.md
+└── scrfd.py
+```
+
+## Get the ONNX file
+
+- Manual export
+
+  Visit the official [SCRFD](https://github.com/deepinsight/insightface/tree/master/detection/scrfd) GitHub repository, follow its instructions to set up the environment, download the `scrfd.pt` model, and use `tools/scrfd2onnx.py` to produce the `onnx` file.
+
+  ```
+  # Download the scrfd model file
+  e.g. download from https://onedrive.live.com/?authkey=%21ABbFJx2JMhNjhNA&id=4A83B6B633B029CC%215542&cid=4A83B6B633B029CC
+
+  # Install the official repo and set up the environment; this export used:
+  - manual setup
+    torch==1.8.0
+    mmcv==1.3.5
+    mmdet==2.7.0
+
+  - or via docker
+    docker pull qyjdefdocker/onnx-scrfd-converter:v0.3
+
+  # Export the onnx file
+  - manually
+    python tools/scrfd2onnx.py configs/scrfd/scrfd_500m.py weights/scrfd_500m.pth --shape 640 --input-img face-xxx.jpg
+
+  - docker
+    the onnx directory of the docker image already contains generated onnx files
+
+  # Move the onnx file into the demo directory
+  cp PATH/TO/SCRFD.onnx PATH/TO/model_zoo/vision/scrfd/
+  ```
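+
+  As a quick sanity check (optional; assumes `onnxruntime` is installed), you can confirm the exported input shape before deploying:
+
+  ```python
+  import onnxruntime as ort
+
+  sess = ort.InferenceSession("SCRFD.onnx")
+  inp = sess.get_inputs()[0]
+  print(inp.name, inp.shape)  # a 640 export should report (1, 3, 640, 640)
+  ```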
+
+## Install FastDeploy
+
+Install FastDeploy with the commands below. Note that this installs `vision-cpu`; you can install `vision-gpu` instead if needed.
+
+```
+# Install the fastdeploy-python tool
+pip install fastdeploy-python
+
+# Install the vision-cpu module
+fastdeploy install vision-cpu
+```
+
+## Python Deployment
+
+Running the script below automatically downloads the test image:
+```
+python scrfd.py
+```
+
+After it finishes, the visualized result is saved locally as `vis_result.jpg` and the detection results are printed as follows:
+```
+FaceDetectionResult: [xmin, ymin, xmax, ymax, score]
+437.670410,194.262772, 478.729828, 244.633911, 0.912465
+418.303650,118.277687, 455.877838, 169.209564, 0.911748
+269.449493,280.810608, 319.466614, 342.681213, 0.908530
+775.553955,237.509979, 814.626526, 286.252350, 0.901296
+565.155945,303.849670, 608.786255, 356.025726, 0.898307
+411.813477,296.117584, 454.560394, 353.151367, 0.889968
+688.620239,153.063812, 728.825195, 204.860321, 0.888146
+686.523071,304.881104, 732.901245, 364.715088, 0.885789
+194.658829,236.657883, 234.194748, 289.099701, 0.881143
+137.273422,286.025787, 183.479523, 344.614441, 0.877399
+289.256775,148.388992, 326.087769, 197.035645, 0.875090
+182.943939,154.105682, 221.422440, 204.460495, 0.871119
+330.301849,207.786499, 367.546692, 260.813232, 0.869559
+659.884216,254.861847, 701.580017, 307.984711, 0.869249
+550.305359,232.336868, 591.702026, 281.101532, 0.866158
+567.473511,127.402367, 604.959839, 175.831696, 0.858938
+```
+
+## Other documents
+
+- [C++ deployment](./cpp/README.md)
+- [SCRFD API documentation](./api.md)
diff --git a/model_zoo/vision/scrfd/api.md b/model_zoo/vision/scrfd/api.md
new file mode 100644
index 000000000..442bd4a25
--- /dev/null
+++ b/model_zoo/vision/scrfd/api.md
@@ -0,0 +1,71 @@
+# SCRFD API Notes
+
+## Python API
+
+### SCRFD class
+```
+fastdeploy.vision.deepinsight.SCRFD(model_file, params_file=None, runtime_option=None, model_format=fd.Frontend.ONNX)
+```
+Loads and initializes the SCRFD model. When model_format is `fd.Frontend.ONNX`, only model_file is needed (e.g. `SCRFD.onnx`); when model_format is `fd.Frontend.PADDLE`, both model_file and params_file are required.
+
+**Parameters**
+
+> * **model_file**(str): path to the model file
+> * **params_file**(str): path to the parameters file
+> * **runtime_option**(RuntimeOption): backend inference options; None means the default configuration
+> * **model_format**(Frontend): model format
+
+#### predict function
+> ```
+> SCRFD.predict(image_data, conf_threshold=0.7, nms_iou_threshold=0.3)
+> ```
+> Prediction interface: takes an image and returns the detection result directly.
+>
+> **Parameters**
+>
+> > * **image_data**(np.ndarray): input data; note it must be in HWC, BGR format
+> > * **conf_threshold**(float): confidence threshold for filtering detection boxes
+> > * **nms_iou_threshold**(float): iou threshold used during NMS
+
+See [scrfd.py](./scrfd.py) for example code.
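+
+A short sketch of the Python API (the file paths are placeholders):
+
+```python
+import fastdeploy as fd
+import cv2
+
+model = fd.vision.deepinsight.SCRFD("SCRFD.onnx")
+im = cv2.imread("face.jpg")  # HWC, BGR, as required by predict
+res = model.predict(im, conf_threshold=0.5, nms_iou_threshold=0.4)
+print(res)
+```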
+
+## C++ API
+
+### SCRFD class
+```
+fastdeploy::vision::deepinsight::SCRFD(
+    const string& model_file,
+    const string& params_file = "",
+    const RuntimeOption& runtime_option = RuntimeOption(),
+    const Frontend& model_format = Frontend::ONNX)
+```
+Loads and initializes the SCRFD model. When model_format is `Frontend::ONNX`, only model_file is needed (e.g. `SCRFD.onnx`); when model_format is `Frontend::PADDLE`, both model_file and params_file are required.
+
+**Parameters**
+
+> * **model_file**(str): path to the model file
+> * **params_file**(str): path to the parameters file
+> * **runtime_option**(RuntimeOption): backend inference options; the default uses the default configuration
+> * **model_format**(Frontend): model format
+
+#### Predict function
+> ```
+> SCRFD::Predict(cv::Mat* im, FaceDetectionResult* result,
+>                float conf_threshold = 0.25,
+>                float nms_iou_threshold = 0.4)
+> ```
+> Prediction interface: takes an image and returns the detection result directly.
+>
+> **Parameters**
+>
+> > * **im**: input image; note it must be in HWC, BGR format
+> > * **result**: detection result, including the boxes and the confidence of each box
+> > * **conf_threshold**: confidence threshold for filtering detection boxes
+> > * **nms_iou_threshold**: iou threshold used during NMS
+
+See [cpp/scrfd.cc](cpp/scrfd.cc) for example code.
+
+## Other APIs
+
+- [RuntimeOption configuration for model deployment](../../../docs/api/runtime_option.md)
diff --git a/model_zoo/vision/scrfd/cpp/CMakeLists.txt b/model_zoo/vision/scrfd/cpp/CMakeLists.txt
new file mode 100644
index 000000000..e63971ba1
--- /dev/null
+++ b/model_zoo/vision/scrfd/cpp/CMakeLists.txt
@@ -0,0 +1,17 @@
+PROJECT(scrfd_demo C CXX)
+CMAKE_MINIMUM_REQUIRED (VERSION 3.16)
+
+# In environments with an older ABI, enable compatibility compilation:
+# add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)
+
+# Path to the downloaded and extracted fastdeploy library
+set(FASTDEPLOY_INSTALL_DIR ${PROJECT_SOURCE_DIR}/fastdeploy-linux-x64-0.3.0/)
+
+include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake)
+
+# Add the FastDeploy header directories
+include_directories(${FASTDEPLOY_INCS})
+
+add_executable(scrfd_demo ${PROJECT_SOURCE_DIR}/scrfd.cc)
+# Link against the FastDeploy libraries
+target_link_libraries(scrfd_demo ${FASTDEPLOY_LIBS})
diff --git a/model_zoo/vision/scrfd/cpp/README.md b/model_zoo/vision/scrfd/cpp/README.md
new file mode 100644
index 000000000..fe2ee64d3
--- /dev/null
+++ b/model_zoo/vision/scrfd/cpp/README.md
@@ -0,0 +1,76 @@
+# Compile the SCRFD Example
+
+The currently supported model version is [SCRFD CID:17cdeab](https://github.com/deepinsight/insightface/tree/17cdeab12a35efcebc2660453a8cbeae96e20950).
+
+This document explains how to quickly deploy and run inference with [SCRFD](https://github.com/deepinsight/insightface/tree/master/detection/scrfd).
+
+## Get the ONNX file
+
+- Manual export
+
+  Visit the official [SCRFD](https://github.com/deepinsight/insightface/tree/master/detection/scrfd) GitHub repository, follow its instructions to set up the environment, download the `scrfd.pt` model, and use `tools/scrfd2onnx.py` to produce the `onnx` file.
+
+  ```
+  # Download the scrfd model file
+  e.g. download from https://onedrive.live.com/?authkey=%21ABbFJx2JMhNjhNA&id=4A83B6B633B029CC%215542&cid=4A83B6B633B029CC
+
+  # Install the official repo and set up the environment; this export used:
+  - manual setup
+    torch==1.8.0
+    mmcv==1.3.5
+    mmdet==2.7.0
+
+  - or via docker
+    docker pull qyjdefdocker/onnx-scrfd-converter:v0.3
+
+  # Export the onnx file
+  - manually
+    python tools/scrfd2onnx.py configs/scrfd/scrfd_500m.py weights/scrfd_500m.pth --shape 640 --input-img face-xxx.jpg
+
+  - docker
+    the onnx directory of the docker image already contains generated onnx files
+  ```
+
+## Run the demo
+
+```
+# Download and extract the prebuilt library
+wget https://bj.bcebos.com/paddle2onnx/fastdeploy/fastdeploy-linux-x64-0.0.3.tgz
+tar xvf fastdeploy-linux-x64-0.0.3.tgz
+
+# Compile the example code
+mkdir build && cd build
+cmake ..
+make -j
+
+# Move the onnx file into the demo build directory
+cp PATH/TO/SCRFD.onnx PATH/TO/model_zoo/vision/scrfd/cpp/build/
+
+# Download the test image
+wget https://raw.githubusercontent.com/DefTruth/lite.ai.toolkit/main/examples/lite/resources/test_lite_face_detector_3.jpg
+
+# Run
+./scrfd_demo
+```
+
+After it finishes, the visualized result is saved locally as `vis_result.jpg` and the detection boxes are printed to the terminal as shown below:
+```
+FaceDetectionResult: [xmin, ymin, xmax, ymax, score]
+437.670410,194.262772, 478.729828, 244.633911, 0.912465
+418.303650,118.277687, 455.877838, 169.209564, 0.911748
+269.449493,280.810608, 319.466614, 342.681213, 0.908530
+775.553955,237.509979, 814.626526, 286.252350, 0.901296
+565.155945,303.849670, 608.786255, 356.025726, 0.898307
+411.813477,296.117584, 454.560394, 353.151367, 0.889968
+688.620239,153.063812, 728.825195, 204.860321, 0.888146
+686.523071,304.881104, 732.901245, 364.715088, 0.885789
+194.658829,236.657883, 234.194748, 289.099701, 0.881143
+137.273422,286.025787, 183.479523, 344.614441, 0.877399
+289.256775,148.388992, 326.087769, 197.035645, 0.875090
+182.943939,154.105682, 221.422440, 204.460495, 0.871119
+330.301849,207.786499, 367.546692, 260.813232, 0.869559
+659.884216,254.861847, 701.580017, 307.984711, 0.869249
+550.305359,232.336868, 591.702026, 281.101532, 0.866158
+567.473511,127.402367, 604.959839, 175.831696, 0.858938
+```
diff --git a/model_zoo/vision/scrfd/cpp/scrfd.cc b/model_zoo/vision/scrfd/cpp/scrfd.cc
new file mode 100644
index 000000000..72dbeb4c7
--- /dev/null
+++ b/model_zoo/vision/scrfd/cpp/scrfd.cc
@@ -0,0 +1,44 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fastdeploy/vision.h"
+
+int main() {
+  namespace vis = fastdeploy::vision;
+  auto model = vis::deepinsight::SCRFD("SCRFD.onnx");
+  if (!model.Initialized()) {
+    std::cerr << "Init Failed." << std::endl;
+    return -1;
+  }
+  cv::Mat im = cv::imread("test_lite_face_detector_3.jpg");
+  cv::Mat vis_im = im.clone();
+
+  // If you load a model exported without key-point prediction, adjust the
+  // use_kps and landmarks_per_face parameters as follows:
+  // model.landmarks_per_face = 0;
+  // model.use_kps = false;
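+  // With use_kps = false the postprocess skips landmark decoding entirely,
+  // so the result's landmarks stay empty; such exports typically produce 6
+  // output tensors (scores + boxes per stride) instead of 9.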
+
+  vis::FaceDetectionResult res;
+  if (!model.Predict(&im, &res)) {
+    std::cerr << "Prediction Failed." << std::endl;
+    return -1;
+  }
+
+  // print the detection results
+  std::cout << res.Str() << std::endl;
+
+  // visualize the prediction
+  vis::Visualize::VisFaceDetection(&vis_im, res, 2, 0.3f);
+  cv::imwrite("vis_result.jpg", vis_im);
+  return 0;
+}
diff --git a/model_zoo/vision/scrfd/scrfd.py b/model_zoo/vision/scrfd/scrfd.py
new file mode 100644
index 000000000..1d4ae8c76
--- /dev/null
+++ b/model_zoo/vision/scrfd/scrfd.py
@@ -0,0 +1,25 @@
+import fastdeploy as fd
+import cv2
+
+# Download the test image
+test_jpg_url = "https://raw.githubusercontent.com/DefTruth/lite.ai.toolkit/main/examples/lite/resources/test_lite_face_detector_3.jpg"
+fd.download(test_jpg_url, ".", show_progress=True)
+
+# Load the model
+model = fd.vision.deepinsight.SCRFD("SCRFD.onnx")
+
+# If you load a model exported without key-point prediction, adjust the
+# use_kps and landmarks_per_face parameters as follows:
+# model.use_kps = False
+# model.landmarks_per_face = 0
+
+# Predict on an image
+im = cv2.imread("test_lite_face_detector_3.jpg")
+result = model.predict(im, conf_threshold=0.5, nms_iou_threshold=0.5)
+
+# Visualize the result
+fd.vision.visualize.vis_face_detection(im, result)
+cv2.imwrite("vis_result.jpg", im)
+
+# Print the prediction
+print(result)
+print(model.runtime_option)
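+
+# A hedged post-processing sketch: this assumes the Python binding exposes the
+# C++ FaceDetectionResult fields as `result.boxes` and `result.scores`
+# (list-like). If it does in your build, boxes can also be drawn manually:
+#
+#   for box, score in zip(result.boxes, result.scores):
+#       if score < 0.5:
+#           continue
+#       x1, y1, x2, y2 = map(int, box)
+#       cv2.rectangle(im, (x1, y1), (x2, y2), (0, 255, 0), 2)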