Merge branch 'develop' into new_ppdet1

This commit is contained in:
Jason
2022-08-04 17:30:54 +08:00
committed by GitHub
22 changed files with 1160 additions and 2 deletions

View File

@@ -17,6 +17,7 @@
#ifdef ENABLE_VISION
#include "fastdeploy/vision/biubug6/retinaface.h"
#include "fastdeploy/vision/deepcam/yolov5face.h"
#include "fastdeploy/vision/deepinsight/scrfd.h"
#include "fastdeploy/vision/linzaer/ultraface.h"
#include "fastdeploy/vision/megvii/yolox.h"
#include "fastdeploy/vision/meituan/yolov6.h"

View File

@@ -0,0 +1,47 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/pybind/main.h"
namespace fastdeploy {
void BindDeepinsight(pybind11::module& m) {
auto deepinsight_module = m.def_submodule(
"deepinsight", "https://github.com/deepinsight");
pybind11::class_<vision::deepinsight::SCRFD, FastDeployModel>(
deepinsight_module, "SCRFD")
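// The constructor binding mirrors the C++ signature:
// (model_file, params_file, RuntimeOption, Frontend).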
.def(pybind11::init<std::string, std::string, RuntimeOption, Frontend>())
.def("predict",
[](vision::deepinsight::SCRFD& self, pybind11::array& data,
float conf_threshold, float nms_iou_threshold) {
auto mat = PyArrayToCvMat(data);
vision::FaceDetectionResult res;
self.Predict(&mat, &res, conf_threshold, nms_iou_threshold);
return res;
})
.def_readwrite("size", &vision::deepinsight::SCRFD::size)
.def_readwrite("padding_value",
&vision::deepinsight::SCRFD::padding_value)
.def_readwrite("is_mini_pad", &vision::deepinsight::SCRFD::is_mini_pad)
.def_readwrite("is_no_pad", &vision::deepinsight::SCRFD::is_no_pad)
.def_readwrite("is_scale_up", &vision::deepinsight::SCRFD::is_scale_up)
.def_readwrite("stride", &vision::deepinsight::SCRFD::stride)
.def_readwrite("use_kps", &vision::deepinsight::SCRFD::use_kps)
.def_readwrite("max_nms", &vision::deepinsight::SCRFD::max_nms)
.def_readwrite("downsample_strides",
&vision::deepinsight::SCRFD::downsample_strides)
.def_readwrite("num_anchors", &vision::deepinsight::SCRFD::num_anchors)
.def_readwrite("landmarks_per_face",
&vision::deepinsight::SCRFD::landmarks_per_face);
}
} // namespace fastdeploy

View File

@@ -0,0 +1,363 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/deepinsight/scrfd.h"
#include "fastdeploy/utils/perf.h"
#include "fastdeploy/vision/utils/utils.h"
namespace fastdeploy {
namespace vision {
namespace deepinsight {
void SCRFD::LetterBox(Mat* mat, const std::vector<int>& size,
const std::vector<float>& color, bool _auto,
bool scale_fill, bool scale_up, int stride) {
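// Uniform scale that fits the image inside `size` while keeping the aspect
// ratio; note that size[0] is the target width and size[1] the target height.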
float scale =
std::min(size[1] * 1.0 / mat->Height(), size[0] * 1.0 / mat->Width());
if (!scale_up) {
scale = std::min(scale, 1.0f);
}
int resize_h = int(round(mat->Height() * scale));
int resize_w = int(round(mat->Width() * scale));
int pad_w = size[0] - resize_w;
int pad_h = size[1] - resize_h;
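// _auto (mini-pad) shrinks the padding to the smallest amount that keeps
// height and width multiples of `stride`; scale_fill instead stretches the
// image to `size` with no padding at all.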
if (_auto) {
pad_h = pad_h % stride;
pad_w = pad_w % stride;
} else if (scale_fill) {
pad_h = 0;
pad_w = 0;
resize_h = size[1];
resize_w = size[0];
}
if (resize_h != mat->Height() || resize_w != mat->Width()) {
Resize::Run(mat, resize_w, resize_h);
}
if (pad_h > 0 || pad_w > 0) {
float half_h = pad_h * 1.0 / 2;
int top = int(round(half_h - 0.1));
int bottom = int(round(half_h + 0.1));
float half_w = pad_w * 1.0 / 2;
int left = int(round(half_w - 0.1));
int right = int(round(half_w + 0.1));
Pad::Run(mat, top, bottom, left, right, color);
}
}
SCRFD::SCRFD(const std::string& model_file, const std::string& params_file,
const RuntimeOption& custom_option, const Frontend& model_format) {
if (model_format == Frontend::ONNX) {
valid_cpu_backends = {Backend::ORT}; // specify the available CPU backends
valid_gpu_backends = {Backend::ORT, Backend::TRT}; // specify the available GPU backends
} else {
valid_cpu_backends = {Backend::PDINFER, Backend::ORT};
valid_gpu_backends = {Backend::PDINFER, Backend::ORT, Backend::TRT};
}
runtime_option = custom_option;
runtime_option.model_format = model_format;
runtime_option.model_file = model_file;
runtime_option.params_file = params_file;
initialized = Initialize();
}
bool SCRFD::Initialize() {
// parameters for preprocess
use_kps = true;
size = {640, 640};
padding_value = {0.0, 0.0, 0.0};
is_mini_pad = false;
is_no_pad = false;
is_scale_up = false;
stride = 32;
downsample_strides = {8, 16, 32};
num_anchors = 2;
landmarks_per_face = 5;
center_points_is_update_ = false;
max_nms = 30000;
// num_outputs = use_kps ? 9 : 6;
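// With keypoint outputs enabled the model exports one landmark tensor per
// stride in addition to scores and boxes, i.e. 3 * fmc output tensors
// instead of 2 * fmc (9 vs. 6 when fmc = 3, 15 vs. 10 when fmc = 5).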
if (!InitRuntime()) {
FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
return false;
}
// Check whether the input shape is dynamic after the Runtime is initialized.
// Note that is_mini_pad must be forced to 'false' to keep a static shape
// after padding (LetterBox) when the input shape is not dynamic.
is_dynamic_input_ = false;
auto shape = InputInfoOfRuntime(0).shape;
for (int i = 0; i < shape.size(); ++i) {
// if height or width is dynamic
if (i >= 2 && shape[i] <= 0) {
is_dynamic_input_ = true;
break;
}
}
if (!is_dynamic_input_) {
is_mini_pad = false;
}
return true;
}
bool SCRFD::Preprocess(Mat* mat, FDTensor* output,
std::map<std::string, std::array<float, 2>>* im_info) {
float ratio = std::min(size[1] * 1.0f / static_cast<float>(mat->Height()),
size[0] * 1.0f / static_cast<float>(mat->Width()));
if (ratio != 1.0) {
int interp = cv::INTER_AREA;
if (ratio > 1.0) {
interp = cv::INTER_LINEAR;
}
int resize_h = int(mat->Height() * ratio);
int resize_w = int(mat->Width() * ratio);
Resize::Run(mat, resize_w, resize_h, -1, -1, interp);
}
// scrfd's preprocess steps
// 1. letterbox
// 2. BGR->RGB
// 3. HWC->CHW
SCRFD::LetterBox(mat, size, padding_value, is_mini_pad, is_no_pad,
is_scale_up, stride);
BGR2RGB::Run(mat);
// Normalize::Run(mat, std::vector<float>(mat->Channels(), 0.0),
// std::vector<float>(mat->Channels(), 1.0));
// Compute `result = mat * alpha + beta` directly by channel
// Original Repo/tools/scrfd.py: cv2.dnn.blobFromImage(img, 1.0/128,
// input_size, (127.5, 127.5, 127.5), swapRB=True)
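// Since (x - 127.5) / 128 == x * (1 / 128) + (-127.5 / 128), the same
// normalization maps to alpha = 1/128 and beta = -127.5/128 per channel.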
std::vector<float> alpha = {1.f / 128.f, 1.f / 128.f, 1.f / 128.f};
std::vector<float> beta = {-127.5f / 128.f, -127.5f / 128.f, -127.5f / 128.f};
Convert::Run(mat, alpha, beta);
// Record output shape of preprocessed image
(*im_info)["output_shape"] = {static_cast<float>(mat->Height()),
static_cast<float>(mat->Width())};
HWC2CHW::Run(mat);
Cast::Run(mat, "float");
mat->ShareWithTensor(output);
output->shape.insert(output->shape.begin(), 1); // reshape to [1, c, h, w]
return true;
}
void SCRFD::GeneratePoints() {
if (center_points_is_update_ && !is_dynamic_input_) {
return;
}
// 8, 16, 32
for (auto local_stride : downsample_strides) {
unsigned int num_grid_w = size[0] / local_stride;
unsigned int num_grid_h = size[1] / local_stride;
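// Each grid center is duplicated num_anchors times so the flattened point
// list lines up one-to-one with the rows of the per-stride output tensors.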
// y
for (unsigned int i = 0; i < num_grid_h; ++i) {
// x
for (unsigned int j = 0; j < num_grid_w; ++j) {
// num_anchors, col major
for (unsigned int k = 0; k < num_anchors; ++k) {
SCRFDPoint point;
point.cx = static_cast<float>(j);
point.cy = static_cast<float>(i);
center_points_[local_stride].push_back(point);
}
}
}
}
center_points_is_update_ = true;
}
bool SCRFD::Postprocess(
std::vector<FDTensor>& infer_result, FaceDetectionResult* result,
const std::map<std::string, std::array<float, 2>>& im_info,
float conf_threshold, float nms_iou_threshold) {
// number of downsample_strides
int fmc = downsample_strides.size();
// scrfd has 6, 9, 10 or 15 output tensors
FDASSERT((infer_result.size() == 9 || infer_result.size() == 6 ||
infer_result.size() == 10 || infer_result.size() == 15),
"The number of output tensors must be 6, 9, 10 or 15 for SCRFD.");
FDASSERT((fmc == 3 || fmc == 5), "The fmc must be 3 or 5");
FDASSERT((infer_result.at(0).shape[0] == 1), "Only support batch = 1 now.");
for (int i = 0; i < fmc; ++i) {
if (infer_result.at(i).dtype != FDDataType::FP32) {
FDERROR << "Only support post process with float32 data." << std::endl;
return false;
}
}
int total_num_boxes = 0;
// compute the space to reserve
for (int f = 0; f < fmc; ++f) {
total_num_boxes += infer_result.at(f).shape[1];
}
GeneratePoints();
result->Clear();
// scale the boxes back to the original image shape
auto iter_out = im_info.find("output_shape");
auto iter_ipt = im_info.find("input_shape");
FDASSERT(iter_out != im_info.end() && iter_ipt != im_info.end(),
"Cannot find input_shape or output_shape from im_info.");
float out_h = iter_out->second[0];
float out_w = iter_out->second[1];
float ipt_h = iter_ipt->second[0];
float ipt_w = iter_ipt->second[1];
float scale = std::min(out_h / ipt_h, out_w / ipt_w);
float pad_h = (out_h - ipt_h * scale) / 2.0f;
float pad_w = (out_w - ipt_w * scale) / 2.0f;
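// LetterBox pads symmetrically, so the per-side offset is half of the total
// padding on each axis.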
if (is_mini_pad) {
// corresponds to the _auto=true logic in LetterBox
pad_h = static_cast<float>(static_cast<int>(pad_h) % stride);
pad_w = static_cast<float>(static_cast<int>(pad_w) % stride);
}
// landmarks_per_face must be set before calling Reserve
result->landmarks_per_face = landmarks_per_face;
result->Reserve(total_num_boxes);
unsigned int count = 0;
// loop each stride
for (int f = 0; f < fmc; ++f) {
float* score_ptr = static_cast<float*>(infer_result.at(f).Data());
float* bbox_ptr = static_cast<float*>(infer_result.at(f + fmc).Data());
const unsigned int num_points = infer_result.at(f).shape[1];
int current_stride = downsample_strides[f];
auto& stride_points = center_points_[current_stride];
// loop each anchor
for (unsigned int i = 0; i < num_points; ++i) {
const float cls_conf = score_ptr[i];
if (cls_conf < conf_threshold) continue; // filter
auto& point = stride_points.at(i);
const float cx = point.cx; // cx
const float cy = point.cy; // cy
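// SCRFD regresses distances (l, t, r, b) from the anchor center in units of
// the current stride: scale by the stride, remove the letterbox padding,
// then divide by `scale` to map back to the original image.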
// bbox
const float* offsets = bbox_ptr + i * 4;
float l = offsets[0]; // left
float t = offsets[1]; // top
float r = offsets[2]; // right
float b = offsets[3]; // bottom
float x1 =
((cx - l) * static_cast<float>(current_stride) - static_cast<float>(pad_w)) / scale; // cx - l x1
float y1 =
((cy - t) * static_cast<float>(current_stride) - static_cast<float>(pad_h)) / scale; // cy - t y1
float x2 =
((cx + r) * static_cast<float>(current_stride) - static_cast<float>(pad_w)) / scale; // cx + r x2
float y2 =
((cy + b) * static_cast<float>(current_stride) - static_cast<float>(pad_h)) / scale; // cy + b y2
result->boxes.emplace_back(std::array<float, 4>{x1, y1, x2, y2});
result->scores.push_back(cls_conf);
if (use_kps) {
float* landmarks_ptr =
static_cast<float*>(infer_result.at(f + 2 * fmc).Data());
// landmarks
const float* kps_offsets = landmarks_ptr + i * (landmarks_per_face * 2);
for (unsigned int j = 0; j < landmarks_per_face * 2; j += 2) {
float kps_l = kps_offsets[j];
float kps_t = kps_offsets[j + 1];
float kps_x = ((cx + kps_l) * static_cast<float>(current_stride) - static_cast<float>(pad_w)) /
scale; // cx + l x
float kps_y = ((cy + kps_t) * static_cast<float>(current_stride) - static_cast<float>(pad_h)) /
scale; // cy + t y
result->landmarks.emplace_back(std::array<float, 2>{kps_x, kps_y});
}
}
count += 1; // limit boxes for nms.
if (count > max_nms) {
break;
}
}
}
// fetch original image shape
FDASSERT((iter_ipt != im_info.end()),
"Cannot find input_shape from im_info.");
if (result->boxes.size() == 0) {
return true;
}
utils::NMS(result, nms_iou_threshold);
// clip boxes into the original image region
for (size_t i = 0; i < result->boxes.size(); ++i) {
result->boxes[i][0] = std::max(result->boxes[i][0], 0.0f);
result->boxes[i][1] = std::max(result->boxes[i][1], 0.0f);
result->boxes[i][2] = std::max(result->boxes[i][2], 0.0f);
result->boxes[i][3] = std::max(result->boxes[i][3], 0.0f);
result->boxes[i][0] = std::min(result->boxes[i][0], ipt_w - 1.0f);
result->boxes[i][1] = std::min(result->boxes[i][1], ipt_h - 1.0f);
result->boxes[i][2] = std::min(result->boxes[i][2], ipt_w - 1.0f);
result->boxes[i][3] = std::min(result->boxes[i][3], ipt_h - 1.0f);
}
// clip landmarks into the original image region
for (size_t i = 0; i < result->landmarks.size(); ++i) {
result->landmarks[i][0] = std::max(result->landmarks[i][0], 0.0f);
result->landmarks[i][1] = std::max(result->landmarks[i][1], 0.0f);
result->landmarks[i][0] = std::min(result->landmarks[i][0], ipt_w - 1.0f);
result->landmarks[i][1] = std::min(result->landmarks[i][1], ipt_h - 1.0f);
}
return true;
}
bool SCRFD::Predict(cv::Mat* im, FaceDetectionResult* result,
float conf_threshold, float nms_iou_threshold) {
#ifdef FASTDEPLOY_DEBUG
TIMERECORD_START(0)
#endif
Mat mat(*im);
std::vector<FDTensor> input_tensors(1);
std::map<std::string, std::array<float, 2>> im_info;
// Record the shape of image and the shape of preprocessed image
im_info["input_shape"] = {static_cast<float>(mat.Height()),
static_cast<float>(mat.Width())};
im_info["output_shape"] = {static_cast<float>(mat.Height()),
static_cast<float>(mat.Width())};
if (!Preprocess(&mat, &input_tensors[0], &im_info)) {
FDERROR << "Failed to preprocess input image." << std::endl;
return false;
}
#ifdef FASTDEPLOY_DEBUG
TIMERECORD_END(0, "Preprocess")
TIMERECORD_START(1)
#endif
input_tensors[0].name = InputInfoOfRuntime(0).name;
std::vector<FDTensor> output_tensors;
if (!Infer(input_tensors, &output_tensors)) {
FDERROR << "Failed to inference." << std::endl;
return false;
}
#ifdef FASTDEPLOY_DEBUG
TIMERECORD_END(1, "Inference")
TIMERECORD_START(2)
#endif
if (!Postprocess(output_tensors, result, im_info, conf_threshold,
nms_iou_threshold)) {
FDERROR << "Failed to post process." << std::endl;
return false;
}
#ifdef FASTDEPLOY_DEBUG
TIMERECORD_END(2, "Postprocess")
#endif
return true;
}
} // namespace deepinsight
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,122 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <unordered_map>
#include "fastdeploy/fastdeploy_model.h"
#include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/result.h"
namespace fastdeploy {
namespace vision {
namespace deepinsight {
class FASTDEPLOY_DECL SCRFD : public FastDeployModel {
public:
// params_file does not need to be specified when model_format is ONNX;
// both model_file and params_file are required when model_format is Paddle
SCRFD(const std::string& model_file, const std::string& params_file = "",
const RuntimeOption& custom_option = RuntimeOption(),
const Frontend& model_format = Frontend::ONNX);
// Defines the name of the model
std::string ModelName() const { return "deepinsight/scrfd"; }
// Prediction interface, i.e. the interface called by the user
// im: the user's input; for CV models it is currently defined as cv::Mat
// result: the output struct holding the model prediction
// conf_threshold: postprocessing parameter (confidence threshold)
// nms_iou_threshold: postprocessing parameter (NMS IoU threshold)
virtual bool Predict(cv::Mat* im, FaceDetectionResult* result,
float conf_threshold = 0.25f,
float nms_iou_threshold = 0.4f);
// The following are parameters used at prediction time, mostly for
// pre/postprocessing. After creating the model, users may modify them
// according to the model's requirements and their own needs.
// tuple of (width, height), default (640, 640)
std::vector<int> size;
// padding value; its size should match the number of channels
std::vector<float> padding_value;
// only pad to the minimum rectangle whose height and width are multiples of
// the stride
bool is_mini_pad;
// when is_mini_pad = false and is_no_pad = true, the image is resized
// directly to the target size without padding
bool is_no_pad;
// if is_scale_up is false, the input image can only be scaled down; the
// resize scale cannot exceed 1.0
bool is_scale_up;
// padding stride, for is_mini_pad
int stride;
// downsample strides (namely, steps) for SCRFD to generate anchor points;
// (8, 16, 32) by default
std::vector<int> downsample_strides;
// landmarks_per_face, default 5 in SCRFD
int landmarks_per_face;
// whether the ONNX model outputs include keypoint (landmark) features
bool use_kps;
// upper bound on the number of boxes processed by NMS
int max_nms;
// number of anchors at each stride location
unsigned int num_anchors;
private:
// Initialization, including backend setup and other operations required for inference
bool Initialize();
// Preprocess the input image
// Mat is the data structure defined by FastDeploy
// FDTensor is the preprocessed tensor passed to the backend for inference
// im_info stores preprocessing data that the postprocess step needs
bool Preprocess(Mat* mat, FDTensor* output,
std::map<std::string, std::array<float, 2>>* im_info);
// Postprocess the backend inference result and return it to the user
// infer_result: the output tensors from backend inference
// result: the model's prediction result
// im_info: information recorded during preprocessing, used to restore boxes
// conf_threshold: confidence threshold for filtering boxes during postprocessing
// nms_iou_threshold: IoU threshold used by NMS during postprocessing
bool Postprocess(std::vector<FDTensor>& infer_result,
FaceDetectionResult* result,
const std::map<std::string, std::array<float, 2>>& im_info,
float conf_threshold, float nms_iou_threshold);
void GeneratePoints();
// Apply LetterBox to the image
// mat: the original image as read
// size: the input image size expected by the model
void LetterBox(Mat* mat, const std::vector<int>& size,
const std::vector<float>& color, bool _auto,
bool scale_fill = false, bool scale_up = true,
int stride = 32);
bool is_dynamic_input_;
bool center_points_is_update_;
typedef struct {
float cx;
float cy;
} SCRFDPoint;
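// Anchor center points cached per stride; filled by GeneratePoints() and
// reused across predictions unless the input shape is dynamic.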
std::unordered_map<int, std::vector<SCRFDPoint>> center_points_;
};
} // namespace deepinsight
} // namespace vision
} // namespace fastdeploy

View File

@@ -118,6 +118,21 @@ bool YOLOv5Lite::Initialize() {
FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
return false;
}
// Check whether the input shape is dynamic after the Runtime is initialized.
// Note that is_mini_pad must be forced to 'false' to keep a static shape
// after padding (LetterBox) when the input shape is not dynamic.
is_dynamic_input_ = false;
auto shape = InputInfoOfRuntime(0).shape;
for (int i = 0; i < shape.size(); ++i) {
// if height or width is dynamic
if (i >= 2 && shape[i] <= 0) {
is_dynamic_input_ = true;
break;
}
}
if (!is_dynamic_input_) {
is_mini_pad = false;
}
return true;
}

View File

@@ -126,6 +126,13 @@ class FASTDEPLOY_DECL YOLOv5Lite : public FastDeployModel {
void GenerateAnchors(const std::vector<int>& size,
const std::vector<int>& downsample_strides,
std::vector<Anchor>* anchors, const int num_anchors = 3);
// whether to run inference with a dynamic input shape (e.g. an ONNX model
// exported with dynamic shape). When is_dynamic_input_ is 'false',
// is_mini_pad is forced to 'false'. This value is checked automatically by
// FastDeploy after the internal Runtime is initialized.
bool is_dynamic_input_;
};
} // namespace ppogg
} // namespace vision

View File

@@ -28,6 +28,7 @@ void BindRangiLyu(pybind11::module& m);
void BindLinzaer(pybind11::module& m);
void BindBiubug6(pybind11::module& m);
void BindPpogg(pybind11::module& m);
void BindDeepinsight(pybind11::module& m);
#ifdef ENABLE_VISION_VISUALIZE
void BindVisualize(pybind11::module& m);
#endif
@@ -75,6 +76,7 @@ void BindVision(pybind11::module& m) {
BindLinzaer(m);
BindBiubug6(m);
BindPpogg(m);
BindDeepinsight(m);
#ifdef ENABLE_VISION_VISUALIZE
BindVisualize(m);
#endif

View File

@@ -89,6 +89,21 @@ bool ScaledYOLOv4::Initialize() {
FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
return false;
}
// Check whether the input shape is dynamic after the Runtime is initialized.
// Note that is_mini_pad must be forced to 'false' to keep a static shape
// after padding (LetterBox) when the input shape is not dynamic.
is_dynamic_input_ = false;
auto shape = InputInfoOfRuntime(0).shape;
for (int i = 0; i < shape.size(); ++i) {
// if height or width is dynamic
if (i >= 2 && shape[i] <= 0) {
is_dynamic_input_ = true;
break;
}
}
if (!is_dynamic_input_) {
is_mini_pad = false;
}
return true;
}

View File

@@ -90,6 +90,13 @@ class FASTDEPLOY_DECL ScaledYOLOv4 : public FastDeployModel {
const std::vector<float>& color, bool _auto,
bool scale_fill = false, bool scale_up = true,
int stride = 32);
// whether to run inference with a dynamic input shape (e.g. an ONNX model
// exported with dynamic shape). When is_dynamic_input_ is 'false',
// is_mini_pad is forced to 'false'. This value is checked automatically by
// FastDeploy after the internal Runtime is initialized.
bool is_dynamic_input_;
};
} // namespace wongkinyiu
} // namespace vision

View File

@@ -87,6 +87,21 @@ bool YOLOR::Initialize() {
FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
return false;
}
// Check whether the input shape is dynamic after the Runtime is initialized.
// Note that is_mini_pad must be forced to 'false' to keep a static shape
// after padding (LetterBox) when the input shape is not dynamic.
is_dynamic_input_ = false;
auto shape = InputInfoOfRuntime(0).shape;
for (int i = 0; i < shape.size(); ++i) {
// if height or width is dynamic
if (i >= 2 && shape[i] <= 0) {
is_dynamic_input_ = true;
break;
}
}
if (!is_dynamic_input_) {
is_mini_pad = false;
}
return true;
}
@@ -176,7 +191,7 @@ bool YOLOR::Postprocess(
float pad_h = (out_h - ipt_h * scale) / 2.0f;
float pad_w = (out_w - ipt_w * scale) / 2.0f;
if (is_mini_pad) {
// corresponds to the _auto=true logic in LetterBox
pad_h = static_cast<float>(static_cast<int>(pad_h) % stride);
pad_w = static_cast<float>(static_cast<int>(pad_w) % stride);
}

View File

@@ -89,6 +89,13 @@ class FASTDEPLOY_DECL YOLOR : public FastDeployModel {
const std::vector<float>& color, bool _auto,
bool scale_fill = false, bool scale_up = true,
int stride = 32);
// whether to run inference with a dynamic input shape (e.g. an ONNX model
// exported with dynamic shape). When is_dynamic_input_ is 'false',
// is_mini_pad is forced to 'false'. This value is checked automatically by
// FastDeploy after the internal Runtime is initialized.
bool is_dynamic_input_;
};
} // namespace wongkinyiu
} // namespace vision

View File

@@ -88,6 +88,21 @@ bool YOLOv7::Initialize() {
FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
return false;
}
// Check whether the input shape is dynamic after the Runtime is initialized.
// Note that is_mini_pad must be forced to 'false' to keep a static shape
// after padding (LetterBox) when the input shape is not dynamic.
is_dynamic_input_ = false;
auto shape = InputInfoOfRuntime(0).shape;
for (int i = 0; i < shape.size(); ++i) {
// if height or width is dynamic
if (i >= 2 && shape[i] <= 0) {
is_dynamic_input_ = true;
break;
}
}
if (!is_dynamic_input_) {
is_mini_pad = false;
}
return true;
}
@@ -177,7 +192,7 @@ bool YOLOv7::Postprocess(
float pad_h = (out_h - ipt_h * scale) / 2.0f;
float pad_w = (out_w - ipt_w * scale) / 2.0f;
if (is_mini_pad) {
// corresponds to the _auto=true logic in LetterBox
pad_h = static_cast<float>(static_cast<int>(pad_h) % stride);
pad_w = static_cast<float>(static_cast<int>(pad_w) % stride);
}

View File

@@ -89,6 +89,13 @@ class FASTDEPLOY_DECL YOLOv7 : public FastDeployModel {
const std::vector<float>& color, bool _auto,
bool scale_fill = false, bool scale_up = true,
int stride = 32);
// whether to run inference with a dynamic input shape (e.g. an ONNX model
// exported with dynamic shape). When is_dynamic_input_ is 'false',
// is_mini_pad is forced to 'false'. This value is checked automatically by
// FastDeploy after the internal Runtime is initialized.
bool is_dynamic_input_;
};
} // namespace wongkinyiu
} // namespace vision

View File

@@ -0,0 +1,51 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision.h"
int main() {
namespace vis = fastdeploy::vision;
std::string model_file = "../resources/models/SCRFD.onnx";
std::string img_path = "../resources/images/test_face_det.jpg";
std::string vis_path = "../resources/outputs/deepsight_scrfd_vis_result.jpg";
auto model = vis::deepinsight::SCRFD(model_file);
model.size = {640, 640}; // (width, height)
if (!model.Initialized()) {
std::cerr << "Init Failed! Model: " << model_file << std::endl;
return -1;
} else {
std::cout << "Init Done! Model:" << model_file << std::endl;
}
model.EnableDebug();
cv::Mat im = cv::imread(img_path);
cv::Mat vis_im = im.clone();
vis::FaceDetectionResult res;
if (!model.Predict(&im, &res, 0.3f, 0.3f)) {
std::cerr << "Prediction Failed." << std::endl;
return -1;
} else {
std::cout << "Prediction Done!" << std::endl;
}
// print the detection results
std::cout << res.Str() << std::endl;
// visualize the prediction
vis::Visualize::VisFaceDetection(&vis_im, res, 2, 0.3f);
cv::imwrite(vis_path, vis_im);
std::cout << "Detect Done! Saved: " << vis_path << std::endl;
return 0;
}

View File

@@ -27,3 +27,4 @@ from . import rangilyu
from . import linzaer
from . import biubug6
from . import ppogg
from . import deepinsight

View File

@@ -0,0 +1,158 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import logging
from ... import FastDeployModel, Frontend
from ... import fastdeploy_main as C
class SCRFD(FastDeployModel):
def __init__(self,
model_file,
params_file="",
runtime_option=None,
model_format=Frontend.ONNX):
# Call the base class to initialize the backend option;
# the initialized option is stored in self._runtime_option
super(SCRFD, self).__init__(runtime_option)
self._model = C.vision.deepinsight.SCRFD(
model_file, params_file, self._runtime_option, model_format)
# self.initialized indicates whether the whole model initialized successfully
assert self.initialized, "SCRFD initialize failed."
def predict(self, input_image, conf_threshold=0.7, nms_iou_threshold=0.3):
return self._model.predict(input_image, conf_threshold,
nms_iou_threshold)
# Properties wrapping SCRFD model attributes, mostly preprocessing related.
# They can be modified to change preprocessing, e.g. model.size = [640, 640]
# changes the resize target (provided the model supports it).
@property
def size(self):
return self._model.size
@property
def padding_value(self):
return self._model.padding_value
@property
def is_no_pad(self):
return self._model.is_no_pad
@property
def is_mini_pad(self):
return self._model.is_mini_pad
@property
def is_scale_up(self):
return self._model.is_scale_up
@property
def stride(self):
return self._model.stride
@property
def downsample_strides(self):
return self._model.downsample_strides
@property
def landmarks_per_face(self):
return self._model.landmarks_per_face
@property
def use_kps(self):
return self._model.use_kps
@property
def max_nms(self):
return self._model.max_nms
@property
def num_anchors(self):
return self._model.num_anchors
@size.setter
def size(self, wh):
assert isinstance(wh, (list, tuple)),\
"The value to set `size` must be type of tuple or list."
assert len(wh) == 2,\
"The value to set `size` must contain 2 elements meaning [width, height], but now it contains {} elements.".format(
len(wh))
self._model.size = wh
@padding_value.setter
def padding_value(self, value):
assert isinstance(
value,
list), "The value to set `padding_value` must be type of list."
self._model.padding_value = value
@is_no_pad.setter
def is_no_pad(self, value):
assert isinstance(
value, bool), "The value to set `is_no_pad` must be type of bool."
self._model.is_no_pad = value
@is_mini_pad.setter
def is_mini_pad(self, value):
assert isinstance(
value,
bool), "The value to set `is_mini_pad` must be type of bool."
self._model.is_mini_pad = value
@is_scale_up.setter
def is_scale_up(self, value):
assert isinstance(
value,
bool), "The value to set `is_scale_up` must be type of bool."
self._model.is_scale_up = value
@stride.setter
def stride(self, value):
assert isinstance(
value, int), "The value to set `stride` must be type of int."
self._model.stride = value
@downsample_strides.setter
def downsample_strides(self, value):
assert isinstance(
value,
list), "The value to set `downsample_strides` must be type of list."
self._model.downsample_strides = value
@landmarks_per_face.setter
def landmarks_per_face(self, value):
assert isinstance(
value,
int), "The value to set `landmarks_per_face` must be type of int."
self._model.landmarks_per_face = value
@use_kps.setter
def use_kps(self, value):
assert isinstance(
value, bool), "The value to set `use_kps` must be type of bool."
self._model.use_kps = value
@max_nms.setter
def max_nms(self, value):
assert isinstance(
value, int), "The value to set `max_nms` must be type of int."
self._model.max_nms = value
@num_anchors.setter
def num_anchors(self, value):
assert isinstance(
value, int), "The value to set `num_anchors` must be type of int."
self._model.num_anchors = value

View File

@@ -0,0 +1,92 @@
# SCRFD Deployment Example
The currently supported model version is [SCRFD CID:17cdeab](https://github.com/deepinsight/insightface/tree/17cdeab12a35efcebc2660453a8cbeae96e20950).
This document explains how to run fast deployment inference with [SCRFD](https://github.com/deepinsight/insightface/tree/master/detection/scrfd). The directory is structured as follows:
```
.
├── cpp
│   ├── CMakeLists.txt
│   ├── README.md
│   └── scrfd.cc
├── README.md
└── scrfd.py
```
## Getting the ONNX file
- Manual export
Visit the official [SCRFD](https://github.com/deepinsight/insightface/tree/master/detection/scrfd) GitHub repository, follow its instructions to download the `scrfd.pt` model, then use `tools/scrfd2onnx.py` to export an `onnx` file.
```
# Download the scrfd model file
e.g. download from https://onedrive.live.com/?authkey=%21ABbFJx2JMhNjhNA&id=4A83B6B633B029CC%215542&cid=4A83B6B633B029CC
# Install the official repository and set up the environment; this version was exported with:
- Manual environment setup
  torch==1.8.0
  mmcv==1.3.5
  mmdet==2.7.0
- Docker-based setup
  docker pull qyjdefdocker/onnx-scrfd-converter:v0.3
# Export the onnx file
- Manual export
  python tools/scrfd2onnx.py configs/scrfd/scrfd_500m.py weights/scrfd_500m.pth --shape 640 --input-img face-xxx.jpg
- Docker
  The onnx directory inside the docker image already contains exported onnx files.
# Move the onnx file into the demo directory
cp PATH/TO/SCRFD.onnx PATH/TO/model_zoo/vision/scrfd/
```
## Installing FastDeploy
Install FastDeploy with the commands below. Note that this installs `vision-cpu`; install `vision-gpu` instead if needed.
```
# Install the fastdeploy-python tool
pip install fastdeploy-python
# Install the vision-cpu module
fastdeploy install vision-cpu
```
## Python deployment
Running the code below automatically downloads the test image:
```
python scrfd.py
```
When it finishes, the visualized result is saved locally as `vis_result.jpg` and the detection results are printed as follows:
```
FaceDetectionResult: [xmin, ymin, xmax, ymax, score]
437.670410,194.262772, 478.729828, 244.633911, 0.912465
418.303650,118.277687, 455.877838, 169.209564, 0.911748
269.449493,280.810608, 319.466614, 342.681213, 0.908530
775.553955,237.509979, 814.626526, 286.252350, 0.901296
565.155945,303.849670, 608.786255, 356.025726, 0.898307
411.813477,296.117584, 454.560394, 353.151367, 0.889968
688.620239,153.063812, 728.825195, 204.860321, 0.888146
686.523071,304.881104, 732.901245, 364.715088, 0.885789
194.658829,236.657883, 234.194748, 289.099701, 0.881143
137.273422,286.025787, 183.479523, 344.614441, 0.877399
289.256775,148.388992, 326.087769, 197.035645, 0.875090
182.943939,154.105682, 221.422440, 204.460495, 0.871119
330.301849,207.786499, 367.546692, 260.813232, 0.869559
659.884216,254.861847, 701.580017, 307.984711, 0.869249
550.305359,232.336868, 591.702026, 281.101532, 0.866158
567.473511,127.402367, 604.959839, 175.831696, 0.858938
```
## Other documents
- [C++ deployment](./cpp/README.md)
- [SCRFD API reference](./api.md)

View File

@@ -0,0 +1,71 @@
# SCRFD API Reference
## Python API
### SCRFD class
```
fastdeploy.vision.deepinsight.SCRFD(model_file, params_file="", runtime_option=None, model_format=fd.Frontend.ONNX)
```
Loads and initializes an SCRFD model. When model_format is `fd.Frontend.ONNX`, only model_file (`SCRFD.onnx`) is required; when model_format is `fd.Frontend.PADDLE`, both model_file and params_file must be provided.
**Parameters**
> * **model_file**(str): path to the model file
> * **params_file**(str): path to the parameters file
> * **runtime_option**(RuntimeOption): backend inference configuration; None means the default configuration is used
> * **model_format**(Frontend): model format
#### predict function
> ```
> SCRFD.predict(image_data, conf_threshold=0.7, nms_iou_threshold=0.3)
> ```
> Prediction interface: takes an image and directly returns the detection result.
>
> **Parameters**
>
> > * **image_data**(np.ndarray): input data; note it must be in HWC, BGR format
> > * **conf_threshold**(float): confidence threshold for filtering detection boxes
> > * **nms_iou_threshold**(float): IoU threshold used during NMS
For example code, see [scrfd.py](./scrfd.py).
## C++ API
### SCRFD class
```
fastdeploy::vision::deepinsight::SCRFD(
const string& model_file,
const string& params_file = "",
const RuntimeOption& runtime_option = RuntimeOption(),
const Frontend& model_format = Frontend::ONNX)
```
Loads and initializes an SCRFD model. When model_format is `Frontend::ONNX`, only model_file (`SCRFD.onnx`) is required; when model_format is `Frontend::PADDLE`, both model_file and params_file must be provided.
**Parameters**
> * **model_file**(str): path to the model file
> * **params_file**(str): path to the parameters file
> * **runtime_option**(RuntimeOption): backend inference configuration; the default is a default-constructed configuration
> * **model_format**(Frontend): model format
#### Predict function
> ```
> SCRFD::Predict(cv::Mat* im, FaceDetectionResult* result,
> float conf_threshold = 0.25,
> float nms_iou_threshold = 0.4)
> ```
> Prediction interface: takes an image and directly returns the detection result.
>
> **Parameters**
>
> > * **im**: input image; note it must be in HWC, BGR format
> > * **result**: detection result, including the detection boxes and the confidence of each box
> > * **conf_threshold**: confidence threshold for filtering detection boxes
> > * **nms_iou_threshold**: IoU threshold used during NMS
For example code, see [cpp/scrfd.cc](cpp/scrfd.cc).
## Other APIs
- [RuntimeOption configuration for model deployment](../../../docs/api/runtime_option.md)
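As a minimal illustration (a sketch, not one of the shipped examples), the snippet below passes an explicitly constructed `RuntimeOption` to the constructor above. The option is left at its defaults here; the available configuration fields are described in the RuntimeOption document linked above.
```
#include <iostream>
#include "fastdeploy/vision.h"

int main() {
  // A default-constructed RuntimeOption; configure it per the RuntimeOption
  // document before passing it in (left untouched in this sketch).
  fastdeploy::RuntimeOption option;
  // The remaining constructor arguments keep their defaults.
  auto model = fastdeploy::vision::deepinsight::SCRFD(
      "SCRFD.onnx", "", option, fastdeploy::Frontend::ONNX);
  if (!model.Initialized()) {
    std::cerr << "Init Failed." << std::endl;
    return -1;
  }
  std::cout << "Init Done!" << std::endl;
  return 0;
}
```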

View File

@@ -0,0 +1,17 @@
CMAKE_MINIMUM_REQUIRED(VERSION 3.16)
PROJECT(scrfd_demo C CXX)
# For environments with an older C++ ABI, enable compatible compilation via:
# add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)
# Path to the downloaded and extracted fastdeploy library
set(FASTDEPLOY_INSTALL_DIR ${PROJECT_SOURCE_DIR}/fastdeploy-linux-x64-0.3.0/)
include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake)
# Add the FastDeploy header dependencies
include_directories(${FASTDEPLOY_INCS})
add_executable(scrfd_demo ${PROJECT_SOURCE_DIR}/scrfd.cc)
# Link against the FastDeploy libraries
target_link_libraries(scrfd_demo ${FASTDEPLOY_LIBS})

View File

@@ -0,0 +1,76 @@
# Compiling the SCRFD Example
The currently supported model version is [SCRFD CID:17cdeab](https://github.com/deepinsight/insightface/tree/17cdeab12a35efcebc2660453a8cbeae96e20950).
This document explains how to run fast deployment inference with [SCRFD](https://github.com/deepinsight/insightface/tree/master/detection/scrfd).
## Getting the ONNX file
- Manual export
Visit the official [SCRFD](https://github.com/deepinsight/insightface/tree/master/detection/scrfd) GitHub repository, follow its instructions to download the `scrfd.pt` model, then use `tools/scrfd2onnx.py` to export an `onnx` file.
```
# Download the scrfd model file
e.g. download from https://onedrive.live.com/?authkey=%21ABbFJx2JMhNjhNA&id=4A83B6B633B029CC%215542&cid=4A83B6B633B029CC
# Install the official repository and set up the environment; this version was exported with:
- Manual environment setup
  torch==1.8.0
  mmcv==1.3.5
  mmdet==2.7.0
- Docker-based setup
  docker pull qyjdefdocker/onnx-scrfd-converter:v0.3
# Export the onnx file
- Manual export
  python tools/scrfd2onnx.py configs/scrfd/scrfd_500m.py weights/scrfd_500m.pth --shape 640 --input-img face-xxx.jpg
- Docker
  The onnx directory inside the docker image already contains exported onnx files.
```
## Running the demo
```
# Download and extract the inference library
wget https://bj.bcebos.com/paddle2onnx/fastdeploy/fastdeploy-linux-x64-0.0.3.tgz
tar xvf fastdeploy-linux-x64-0.0.3.tgz
# Build the example code
mkdir build && cd build
cmake ..
make -j
# Move the onnx file into the build directory
cp PATH/TO/SCRFD.onnx PATH/TO/model_zoo/vision/scrfd/cpp/build/
# Download the test image
wget https://raw.githubusercontent.com/DefTruth/lite.ai.toolkit/main/examples/lite/resources/test_lite_face_detector_3.jpg
# Run
./scrfd_demo
```
When it finishes, the visualized result is saved locally as `vis_result.jpg` and the detection boxes are printed to the terminal, as shown below:
```
FaceDetectionResult: [xmin, ymin, xmax, ymax, score]
437.670410,194.262772, 478.729828, 244.633911, 0.912465
418.303650,118.277687, 455.877838, 169.209564, 0.911748
269.449493,280.810608, 319.466614, 342.681213, 0.908530
775.553955,237.509979, 814.626526, 286.252350, 0.901296
565.155945,303.849670, 608.786255, 356.025726, 0.898307
411.813477,296.117584, 454.560394, 353.151367, 0.889968
688.620239,153.063812, 728.825195, 204.860321, 0.888146
686.523071,304.881104, 732.901245, 364.715088, 0.885789
194.658829,236.657883, 234.194748, 289.099701, 0.881143
137.273422,286.025787, 183.479523, 344.614441, 0.877399
289.256775,148.388992, 326.087769, 197.035645, 0.875090
182.943939,154.105682, 221.422440, 204.460495, 0.871119
330.301849,207.786499, 367.546692, 260.813232, 0.869559
659.884216,254.861847, 701.580017, 307.984711, 0.869249
550.305359,232.336868, 591.702026, 281.101532, 0.866158
567.473511,127.402367, 604.959839, 175.831696, 0.858938
```

View File

@@ -0,0 +1,44 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision.h"
int main() {
namespace vis = fastdeploy::vision;
auto model = vis::deepinsight::SCRFD("SCRFD.onnx");
if (!model.Initialized()) {
std::cerr << "Init Failed." << std::endl;
return -1;
}
cv::Mat im = cv::imread("test_lite_face_detector_3.jpg");
cv::Mat vis_im = im.clone();
// If the loaded model has no keypoint outputs, adjust the model parameters
// use_kps and landmarks_per_face as follows:
// model.landmarks_per_face = 0;
// model.use_kps = false;
vis::FaceDetectionResult res;
if (!model.Predict(&im, &res)) {
std::cerr << "Prediction Failed." << std::endl;
return -1;
}
// print the detection results
std::cout << res.Str() << std::endl;
// visualize the prediction
vis::Visualize::VisFaceDetection(&vis_im, res, 2, 0.3f);
cv::imwrite("vis_result.jpg", vis_im);
return 0;
}

View File

@@ -0,0 +1,25 @@
import fastdeploy as fd
import cv2
# download the test image
test_jpg_url = "https://raw.githubusercontent.com/DefTruth/lite.ai.toolkit/main/examples/lite/resources/test_lite_face_detector_3.jpg"
fd.download(test_jpg_url, ".", show_progress=True)
# load the model
model = fd.vision.deepinsight.SCRFD("SCRFD.onnx")
# If the loaded model has no keypoint outputs, adjust the model parameters
# use_kps and landmarks_per_face as follows:
# model.use_kps = False
# model.landmarks_per_face = 0
# run prediction on the image
im = cv2.imread("test_lite_face_detector_3.jpg")
result = model.predict(im, conf_threshold=0.5, nms_iou_threshold=0.5)
# visualize the result
fd.vision.visualize.vis_face_detection(im, result)
cv2.imwrite("vis_result.jpg", im)
# print the prediction result
print(result)
print(model.runtime_option)