[Model] Support BlazeFace Model (#1172)

* fit yolov7face file path

* TODO: add the yolov7face Python Predict interface

* resolve yolov7face.py

* resolve yolov7face.py

* resolve yolov7face.py

* add yolov7face example readme file

* [Doc] fix yolov7face example readme file

* [Doc] fix yolov7face example readme file

* support BlazeFace

* add blazeface readme file

* fix review problem

* fix code style error

* fix review problem

* fix review problem

* fix head file problem

* fix review problem

* fix review problem

* fix readme file problem

* add English readme file

* fix English readme file
CoolCola
2023-02-06 14:24:12 +08:00
committed by GitHub
parent e2de3f36d3
commit 42d14e7119
21 changed files with 1518 additions and 0 deletions


@@ -0,0 +1,34 @@
English | [简体中文](README_CN.md)
# BlazeFace Ready-to-deploy Model
- The BlazeFace deployment implementation comes from [BlazeFace](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.5/configs/face_detection) and the [pre-trained models on WiderFace](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.5/configs/face_detection)
- 1. The `*.pdparams` files provided in the [official repository](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.5/tools) can be deployed after running [export_model.py](#export-paddle-model);
- 2. Developers can train a BlazeFace model on their own data, export it with [export_model.py](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.5/tools/export_model.py), and then complete the deployment.
## Export PADDLE model
Visit the [BlazeFace](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.5/configs/face_detection) GitHub repository, download and install it following the instructions there, download the `.yml` config and `.pdparams` model parameters, then use `export_model.py` to get the Paddle model files (`.yml`, `.pdiparams`, `.pdmodel`).
* Download BlazeFace model parameter file
|Network structure | Input size | Images/GPU | Learning rate schedule | Easy/Medium/Hard Set | Inference latency (SD855) | Model size (MB) | Download | Config file|
|:------------:|:--------:|:----:|:-------:|:-------:|:---------:|:----------:|:---------:|:--------:|
| BlazeFace | 640 | 8 | 1000e | 0.885 / 0.855 / 0.731 | - | 0.472 |[Download link](https://paddledet.bj.bcebos.com/models/blazeface_1000e.pdparams) | [Config file](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.5/configs/face_detection/blazeface_1000e.yml) |
| BlazeFace-FPN-SSH | 640 | 8 | 1000e | 0.907 / 0.883 / 0.793 | - | 0.479 |[Download link](https://paddledet.bj.bcebos.com/models/blazeface_fpn_ssh_1000e.pdparams) | [Config file](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.5/configs/face_detection/blazeface_fpn_ssh_1000e.yml) |
* Export paddle-format file
```bash
python tools/export_model.py -c configs/face_detection/blazeface_1000e.yml -o weights=blazeface_1000e.pdparams --export_serving_model=True
```
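After a successful export, PaddleDetection typically writes the model to an `output_inference/` directory (directory name assumed from PaddleDetection's default export behavior); the resulting `infer_cfg.yml`, `model.pdmodel` and `model.pdiparams` files are what the deployment examples below consume.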
## Detailed Deployment Tutorials
- [Python Deployment](python)
- [C++ Deployment](cpp)
## Release Note
- This tutorial and related code are written based on [BlazeFace](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.5/configs/face_detection)


@@ -0,0 +1,31 @@
# BlazeFace准备部署模型
- BlazeFace部署模型实现来自[BlazeFace](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.5/configs/face_detection),和[基于WiderFace的预训练模型](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.5/configs/face_detection)
- 1[官方库](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.5/tools)中提供的*.params,通过[export_model.py](#导出PADDLE模型)操作后,可进行部署;
- 2开发者基于自己数据训练的BlazeFace模型可按照[export_model.py](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.5/tools/export_model.py)导出模型后,完成部署。
## 导出PADDLE模型
访问[BlazeFace](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.5/configs/face_detection)github库,按照指引下载安装,下载`.yml`和`.pdparams`模型参数,利用 `export_model.py` 得到`paddle`模型文件`.yml, .pdiparams, .pdmodel`。
* 下载BlazeFace模型参数文件
| 网络结构 | 输入尺寸 | 图片个数/GPU | 学习率策略 | Easy/Medium/Hard Set | 预测时延(SD855) | 模型大小(MB) | 下载 | 配置文件 |
|:------------:|:--------:|:----:|:-------:|:-------:|:---------:|:----------:|:---------:|:--------:|
| BlazeFace | 640 | 8 | 1000e | 0.885 / 0.855 / 0.731 | - | 0.472 |[下载链接](https://paddledet.bj.bcebos.com/models/blazeface_1000e.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.5/configs/face_detection/blazeface_1000e.yml) |
| BlazeFace-FPN-SSH | 640 | 8 | 1000e | 0.907 / 0.883 / 0.793 | - | 0.479 |[下载链接](https://paddledet.bj.bcebos.com/models/blazeface_fpn_ssh_1000e.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.5/configs/face_detection/blazeface_fpn_ssh_1000e.yml) |
* 导出paddle格式文件
```bash
python tools/export_model.py -c configs/face_detection/blazeface_1000e.yml -o weights=blazeface_1000e.pdparams --export_serving_model=True
```
## 详细部署文档
- [Python部署](python)
- [C++部署](cpp)
## 版本说明
- 本版本文档和代码基于[BlazeFace](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.5/configs/face_detection) 编写


@@ -0,0 +1,14 @@
PROJECT(infer_demo C CXX)
CMAKE_MINIMUM_REQUIRED (VERSION 3.10)
# Specifies the path to the fastdeploy library after you have downloaded it
option(FASTDEPLOY_INSTALL_DIR "Path of downloaded fastdeploy sdk.")
include(../../../../../FastDeploy.cmake)
# Add the FastDeploy dependency header
include_directories(${FASTDEPLOY_INCS})
add_executable(infer_demo ${PROJECT_SOURCE_DIR}/infer.cc)
# Add the FastDeploy library dependency
target_link_libraries(infer_demo ${FASTDEPLOY_LIBS})


@@ -0,0 +1,78 @@
English | [简体中文](README_CN.md)
# BlazeFace C++ Deployment Example
This directory provides an example in which `infer.cc` quickly completes the deployment of BlazeFace on CPU/GPU.
Before deployment, confirm the following two steps
- 1. The software and hardware environment meets the requirements. Please refer to [FastDeploy Environment Requirements](../../../../../docs/en/build_and_install/download_prebuilt_libraries.md)
- 2. Download the precompiled deployment library and sample code according to your development environment. Please refer to [FastDeploy Precompiled Library](../../../../../docs/en/build_and_install/download_prebuilt_libraries.md)
Taking the CPU inference on Linux as an example, the compilation test can be completed by executing the following command in this directory.
```bash
mkdir build
cd build
# Download the FastDeploy precompiled library. Users can choose the appropriate version from the `FastDeploy Precompiled Library` mentioned above
wget https://bj.bcebos.com/fastdeploy/release/cpp/fastdeploy-linux-x64-x.x.x.tgz # x.x.x >= 1.0.4
tar xvf fastdeploy-linux-x64-x.x.x.tgz # x.x.x >= 1.0.4
cmake .. -DFASTDEPLOY_INSTALL_DIR=${PWD}/fastdeploy-linux-x64-x.x.x # x.x.x >= 1.0.4
make -j
# Download the officially converted BlazeFace model files and test images
wget https://raw.githubusercontent.com/DefTruth/lite.ai.toolkit/main/examples/lite/resources/test_lite_face_detector_3.jpg
wget https://bj.bcebos.com/paddlehub/fastdeploy/blzeface-1000e.tgz
tar -xvf blzeface-1000e.tgz
# Use the blazeface-1000e model
# CPU inference
./infer_demo blazeface-1000e/ test_lite_face_detector_3.jpg 0
# GPU Inference
./infer_demo blazeface-1000e/ test_lite_face_detector_3.jpg 1
```
The visualized result after running is as follows
<img width="640" src="https://user-images.githubusercontent.com/49013063/206170111-843febb6-67d6-4c46-a121-d87d003bba21.jpg">
The above commands work for Linux or macOS. For SDK usage on Windows, refer to:
- [How to use FastDeploy C++ SDK in Windows](../../../../../docs/cn/faq/use_sdk_on_windows.md)
## BlazeFace C++ Interface
### BlazeFace Class
```c++
fastdeploy::vision::facedet::BlazeFace(
const string& model_file,
const string& params_file = "",
const string& config_file = "",
const RuntimeOption& runtime_option = RuntimeOption(),
const ModelFormat& model_format = ModelFormat::PADDLE)
```
BlazeFace model loading and initialization, where model_file is the exported Paddle format model file
**Parameter**
> * **model_file**(str): Model file path
> * **params_file**(str): Parameter file path. Merely passing an empty string when the model is in ONNX format
> * **config_file**(str): Config file path. Merely passing an empty string when the model is in ONNX format
> * **runtime_option**(RuntimeOption): Backend inference configuration. None by default, which is the default configuration
> * **model_format**(ModelFormat): Model format. PADDLE format by default
#### Predict Function
> ```c++
> BlazeFace::Predict(const cv::Mat& im, FaceDetectionResult* result)
> ```
>
> Model prediction interface. Input images and output detection results.
>
> **Parameter**
>
> > * **im**: Input image in HWC layout, BGR format
> > * **result**: Detection results, including detection box and confidence of each box. Refer to [Vision Model Prediction Result](../../../../../docs/api/vision_results/) for FaceDetectionResult
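For a complete, runnable example that loads the model, calls `Predict` and visualizes the result, see `infer.cc` in this directory.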
- [Model Description](../../)
- [Python Deployment](../python)
- [Vision Model Prediction Results](../../../../../docs/api/vision_results/)


@@ -0,0 +1,77 @@
[English](README.md) | 简体中文
# BlazeFace C++部署示例
本目录下提供`infer.cc`快速完成BlazeFace在CPU/GPU部署的示例。
在部署前,需确认以下两个步骤
- 1. 软硬件环境满足要求,参考[FastDeploy环境要求](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md)
- 2. 根据开发环境,下载预编译部署库和samples代码,参考[FastDeploy预编译库](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md)
以Linux上CPU推理为例,在本目录执行如下命令即可完成编译测试
```bash
mkdir build
cd build
# 下载FastDeploy预编译库,用户可在上文提到的`FastDeploy预编译库`中自行选择合适的版本使用
wget https://bj.bcebos.com/fastdeploy/release/cpp/fastdeploy-linux-x64-x.x.x.tgz # x.x.x >= 1.0.4
tar xvf fastdeploy-linux-x64-x.x.x.tgz # x.x.x >= 1.0.4
cmake .. -DFASTDEPLOY_INSTALL_DIR=${PWD}/fastdeploy-linux-x64-x.x.x # x.x.x >= 1.0.4
make -j
#下载官方转换好的BlazeFace模型文件和测试图片
wget https://raw.githubusercontent.com/DefTruth/lite.ai.toolkit/main/examples/lite/resources/test_lite_face_detector_3.jpg
wget https://bj.bcebos.com/paddlehub/fastdeploy/blzeface-1000e.tgz
tar -xvf blzeface-1000e.tgz
#使用blazeface-1000e模型
# CPU推理
./infer_demo blazeface-1000e/ test_lite_face_detector_3.jpg 0
# GPU推理
./infer_demo blazeface-1000e/ test_lite_face_detector_3.jpg 1
```
运行完成可视化结果如下图所示
<img width="640" src="https://user-images.githubusercontent.com/49013063/206170111-843febb6-67d6-4c46-a121-d87d003bba21.jpg">
以上命令只适用于Linux或MacOS, Windows下SDK的使用方式请参考:
- [如何在Windows中使用FastDeploy C++ SDK](../../../../../docs/cn/faq/use_sdk_on_windows.md)
## BlazeFace C++接口
### BlazeFace类
```c++
fastdeploy::vision::facedet::BlazeFace(
const string& model_file,
const string& params_file = "",
const string& config_file = "",
const RuntimeOption& runtime_option = RuntimeOption(),
const ModelFormat& model_format = ModelFormat::PADDLE)
```
BlazeFace模型加载和初始化,其中model_file为导出的PADDLE模型格式。
**参数**
> * **model_file**(str): 模型文件路径
> * **params_file**(str): 参数文件路径,当模型格式为ONNX时,此参数传入空字符串即可
> * **config_file**(str): 配置文件路径,当模型格式为ONNX时,此参数传入空字符串即可
> * **runtime_option**(RuntimeOption): 后端推理配置,默认为None,即采用默认配置
> * **model_format**(ModelFormat): 模型格式,默认为PADDLE格式
#### Predict函数
> ```c++
> BlazeFace::Predict(const cv::Mat& im, FaceDetectionResult* result)
> ```
>
> 模型预测接口,输入图像直接输出检测结果。
>
> **参数**
>
> > * **im**: 输入图像,注意需为HWC,BGR格式
> > * **result**: 检测结果,包括检测框,各个框的置信度,FaceDetectionResult说明参考[视觉模型预测结果](../../../../../docs/api/vision_results/)
- [模型介绍](../../)
- [Python部署](../python)
- [视觉模型预测结果](../../../../../docs/api/vision_results/)


@@ -0,0 +1,94 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision.h"
#ifdef WIN32
const char sep = '\\';
#else
const char sep = '/';
#endif
void CpuInfer(const std::string& model_dir, const std::string& image_file) {
auto model_file = model_dir + sep + "model.pdmodel";
auto params_file = model_dir + sep + "model.pdiparams";
auto config_file = model_dir + sep + "infer_cfg.yml";
auto option = fastdeploy::RuntimeOption();
option.UseCpu();
auto model = fastdeploy::vision::facedet::BlazeFace(
model_file, params_file, config_file, option);
if (!model.Initialized()) {
std::cerr << "Failed to initialize." << std::endl;
return;
}
auto im = cv::imread(image_file);
fastdeploy::vision::FaceDetectionResult res;
if (!model.Predict(im, &res)) {
std::cerr << "Failed to predict." << std::endl;
return;
}
std::cout << res.Str() << std::endl;
auto vis_im = fastdeploy::vision::VisFaceDetection(im, res);
cv::imwrite("vis_result.jpg", vis_im);
std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
}
void GpuInfer(const std::string& model_dir, const std::string& image_file) {
auto model_file = model_dir + sep + "model.pdmodel";
auto params_file = model_dir + sep + "model.pdiparams";
auto config_file = model_dir + sep + "infer_cfg.yml";
auto option = fastdeploy::RuntimeOption();
option.UseGpu();
auto model = fastdeploy::vision::facedet::BlazeFace(
model_file, params_file, config_file, option);
if (!model.Initialized()) {
std::cerr << "Failed to initialize." << std::endl;
return;
}
auto im = cv::imread(image_file);
fastdeploy::vision::FaceDetectionResult res;
if (!model.Predict(im, &res)) {
std::cerr << "Failed to predict." << std::endl;
return;
}
std::cout << res.Str() << std::endl;
auto vis_im = fastdeploy::vision::VisFaceDetection(im, res);
cv::imwrite("vis_result.jpg", vis_im);
std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
}
int main(int argc, char* argv[]) {
if (argc < 4) {
std::cout << "Usage: infer_demo path/to/model path/to/image run_option, "
"e.g ./infer_model yolov5s-face.onnx ./test.jpeg 0"
<< std::endl;
std::cout << "The data type of run_option is int, 0: run with cpu; 1: run "
"with gpu; 2: run with gpu and use tensorrt backend."
<< std::endl;
return -1;
}
if (std::atoi(argv[3]) == 0) {
CpuInfer(argv[1], argv[2]);
} else if (std::atoi(argv[3]) == 1) {
GpuInfer(argv[1], argv[2]);
}
return 0;
}


@@ -0,0 +1,68 @@
English | [简体中文](README_CN.md)
# BlazeFace Python Deployment Example
Before deployment, confirm the following two steps
- 1. Software and hardware should meet the requirements. Please refer to [FastDeploy Environment Requirements](../../../../../docs/en/build_and_install/download_prebuilt_libraries.md)
- 2. Install FastDeploy Python whl package. Refer to [FastDeploy Python Installation](../../../../../docs/en/build_and_install/download_prebuilt_libraries.md)
This directory provides an example in which `infer.py` quickly completes the deployment of BlazeFace on CPU/GPU. Run the following script to complete it:
```bash
# Download the example code for deployment
git clone https://github.com/PaddlePaddle/FastDeploy.git
cd FastDeploy/examples/vision/facedet/blazeface/python/
# Download BlazeFace model files and test images
wget https://raw.githubusercontent.com/DefTruth/lite.ai.toolkit/main/examples/lite/resources/test_lite_face_detector_3.jpg
wget https://bj.bcebos.com/paddlehub/fastdeploy/blazeface-1000e.tgz
tar -xvf blazeface-1000e.tgz
# Use the blazeface-1000e model
# CPU Inference
python infer.py --model blazeface-1000e/ --image test_lite_face_detector_3.jpg --device cpu
# GPU Inference
python infer.py --model blazeface-1000e/ --image test_lite_face_detector_3.jpg --device gpu
```
The visualized result after running is as follows
<img width="640" src="https://user-images.githubusercontent.com/67993288/184301839-a29aefae-16c9-4196-bf9d-9c6cf694f02d.jpg">
## BlazeFace Python Interface
```python
fastdeploy.vision.facedet.BlazeFace(model_file, params_file=None, runtime_option=None, config_file=None, model_format=ModelFormat.PADDLE)
```
BlazeFace model loading and initialization, where model_file is the exported Paddle format model file
**Parameter**
> * **model_file**(str): Model file path
> * **params_file**(str): Parameter file path. No need to set when the model is in ONNX format
> * **config_file**(str): Config file path. No need to set when the model is in ONNX format
> * **runtime_option**(RuntimeOption): Backend inference configuration. None by default, which is the default configuration
> * **model_format**(ModelFormat): Model format. PADDLE format by default
### predict function
> ```python
> BlazeFace.predict(input_image)
> ```
> The confidence threshold can be adjusted via `BlazeFace.postprocessor.conf_threshold = 0.2`
>
> Model prediction interface. Input images and output detection results.
>
> **Parameter**
>
> > * **input_image**(np.ndarray): Input image in HWC layout, BGR format
> **Return**
>
> > Return a `fastdeploy.vision.FaceDetectionResult` structure. Refer to [Vision Model Prediction Results](../../../../../docs/api/vision_results/) for its description.
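A minimal end-to-end sketch, assuming the `blazeface-1000e` model directory and test image downloaded by the commands above:
```python
import cv2
import fastdeploy as fd

# Load the exported Paddle model; all three files live in the model directory
model = fd.vision.facedet.BlazeFace(
    "blazeface-1000e/model.pdmodel",
    "blazeface-1000e/model.pdiparams",
    "blazeface-1000e/infer_cfg.yml")

# Optionally adjust the postprocessing thresholds before predicting
model.postprocessor.conf_threshold = 0.2  # default 0.5
model.postprocessor.nms_threshold = 0.3   # default 0.3

im = cv2.imread("test_lite_face_detector_3.jpg")
result = model.predict(im)  # returns fastdeploy.vision.FaceDetectionResult
print(result)
```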
## Other Documents
- [BlazeFace Model Description](..)
- [BlazeFace C++ Deployment](../cpp)
- [Model Prediction Results](../../../../../docs/api/vision_results/)


@@ -0,0 +1,68 @@
[English](README.md) | 简体中文
# BlazeFace Python部署示例
在部署前,需确认以下两个步骤
- 1. 软硬件环境满足要求,参考[FastDeploy环境要求](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md)
- 2. FastDeploy Python whl包安装,参考[FastDeploy Python安装](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md)
本目录下提供`infer.py`快速完成BlazeFace在CPU/GPU部署的示例。执行如下脚本即可完成
```bash
#下载部署示例代码
git clone https://github.com/PaddlePaddle/FastDeploy.git
cd FastDeploy/examples/vision/facedet/blazeface/python/
#下载BlazeFace模型文件和测试图片
wget https://raw.githubusercontent.com/DefTruth/lite.ai.toolkit/main/examples/lite/resources/test_lite_face_detector_3.jpg
wget https://bj.bcebos.com/paddlehub/fastdeploy/blazeface-1000e.tgz
tar -xvf blazeface-1000e.tgz
#使用blazeface-1000e模型
# CPU推理
python infer.py --model blazeface-1000e/ --image test_lite_face_detector_3.jpg --device cpu
# GPU推理
python infer.py --model blazeface-1000e/ --image test_lite_face_detector_3.jpg --device gpu
```
运行完成可视化结果如下图所示
<img width="640" src="https://user-images.githubusercontent.com/67993288/184301839-a29aefae-16c9-4196-bf9d-9c6cf694f02d.jpg">
## BlazeFace Python接口
```python
fastdeploy.vision.facedet.BlazeFace(model_file, params_file=None, runtime_option=None, config_file=None, model_format=ModelFormat.PADDLE)
```
BlazeFace模型加载和初始化
**参数**
> * **model_file**(str): 模型文件路径
> * **params_file**(str): 参数文件路径,当模型格式为ONNX格式时,此参数无需设定
> * **config_file**(str): config文件路径,当模型格式为ONNX格式时,此参数无需设定
> * **runtime_option**(RuntimeOption): 后端推理配置,默认为None,即采用默认配置
> * **model_format**(ModelFormat): 模型格式,默认为PADDLE
### predict函数
> ```python
> BlazeFace.predict(input_image)
> ```
> 可通过 `BlazeFace.postprocessor.conf_threshold = 0.2` 来修改conf_threshold
>
> 模型预测接口,输入图像直接输出检测结果。
>
> **参数**
>
> > * **input_image**(np.ndarray): 输入数据,注意需为HWC,BGR格式
> **返回**
>
> > 返回`fastdeploy.vision.FaceDetectionResult`结构体,结构体说明参考文档[视觉模型预测结果](../../../../../docs/api/vision_results/)
## 其它文档
- [BlazeFace 模型介绍](..)
- [BlazeFace C++部署](../cpp)
- [模型预测结果说明](../../../../../docs/api/vision_results/)


@@ -0,0 +1,58 @@
import fastdeploy as fd
import cv2
import os
def parse_arguments():
import argparse
import ast
parser = argparse.ArgumentParser()
parser.add_argument(
"--model", required=True, help="Path of blazeface model dir.")
parser.add_argument(
"--image", required=True, help="Path of test image file.")
parser.add_argument(
"--device",
type=str,
default='cpu',
help="Type of inference device, support 'cpu' or 'gpu'.")
parser.add_argument(
"--use_trt",
type=ast.literal_eval,
default=False,
help="Wether to use tensorrt.")
return parser.parse_args()
def build_option(args):
option = fd.RuntimeOption()
if args.device.lower() == "gpu":
option.use_gpu()
if args.use_trt:
option.use_trt_backend()
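        # TensorRT needs a fixed input shape; 1x3x640x640 matches the 640x640 export size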
option.set_trt_input_shape("images", [1, 3, 640, 640])
return option
args = parse_arguments()
model_dir = args.model
model_file = os.path.join(model_dir, "model.pdmodel")
params_file = os.path.join(model_dir, "model.pdiparams")
config_file = os.path.join(model_dir, "infer_cfg.yml")
# Configure runtime and load the model
runtime_option = build_option(args)
model = fd.vision.facedet.BlazeFace(model_file, params_file, config_file, runtime_option=runtime_option)
# Predict image detection results
im = cv2.imread(args.image)
result = model.predict(im)
print(result)
# Visualization of prediction Results
vis_im = fd.vision.vis_face_detection(im, result)
cv2.imwrite("visualized_result.jpg", vis_im)
print("Visualized result save in ./visualized_result.jpg")


@@ -41,6 +41,7 @@
#include "fastdeploy/vision/facedet/contrib/ultraface.h"
#include "fastdeploy/vision/facedet/contrib/yolov5face.h"
#include "fastdeploy/vision/facedet/contrib/yolov7face/yolov7face.h"
#include "fastdeploy/vision/facedet/ppdet/blazeface/blazeface.h"
#include "fastdeploy/vision/faceid/contrib/insightface/model.h"
#include "fastdeploy/vision/faceid/contrib/adaface/adaface.h"
#include "fastdeploy/vision/headpose/contrib/fsanet.h"


@@ -20,6 +20,7 @@ void BindRetinaFace(pybind11::module& m);
void BindUltraFace(pybind11::module& m);
void BindYOLOv5Face(pybind11::module& m);
void BindYOLOv7Face(pybind11::module& m);
void BindBlazeFace(pybind11::module& m);
void BindSCRFD(pybind11::module& m);
void BindFaceDet(pybind11::module& m) {
@@ -28,6 +29,7 @@ void BindFaceDet(pybind11::module& m) {
BindUltraFace(facedet_module);
BindYOLOv5Face(facedet_module);
BindYOLOv7Face(facedet_module);
BindBlazeFace(facedet_module);
BindSCRFD(facedet_module);
}
} // namespace fastdeploy


@@ -0,0 +1,93 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/facedet/ppdet/blazeface/blazeface.h"
#include "fastdeploy/utils/perf.h"
#include "fastdeploy/vision/utils/utils.h"
namespace fastdeploy{
namespace vision{
namespace facedet{
BlazeFace::BlazeFace(const std::string& model_file,
const std::string& params_file,
const std::string& config_file,
const RuntimeOption& custom_option,
const ModelFormat& model_format)
: preprocessor_(config_file){
valid_cpu_backends = {Backend::OPENVINO, Backend::PDINFER, Backend::LITE};
valid_gpu_backends = {Backend::OPENVINO, Backend::LITE, Backend::PDINFER};
runtime_option = custom_option;
runtime_option.model_format = model_format;
runtime_option.model_file = model_file;
runtime_option.params_file = params_file;
initialized = Initialize();
}
bool BlazeFace::Initialize(){
if (!InitRuntime()){
FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
return false;
}
return true;
}
bool BlazeFace::Predict(const cv::Mat& im, FaceDetectionResult* result){
std::vector<FaceDetectionResult> results;
if (!this->BatchPredict({im}, &results)) {
return false;
}
*result = std::move(results[0]);
return true;
}
bool BlazeFace::BatchPredict(const std::vector<cv::Mat>& images,
std::vector<FaceDetectionResult>* results){
std::vector<FDMat> fd_images = WrapMat(images);
FDASSERT(images.size() == 1, "Only support batch = 1 now.");
std::vector<std::map<std::string, std::array<float, 2>>> ims_info;
if (!preprocessor_.Run(&fd_images, &reused_input_tensors_, &ims_info)) {
FDERROR << "Failed to preprocess the input image." << std::endl;
return false;
}
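  // Name the tensors as the exported Paddle model expects:
  // image, scale_factor and im_shape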
reused_input_tensors_[0].name = "image";
reused_input_tensors_[1].name = "scale_factor";
reused_input_tensors_[2].name = "im_shape";
// Some models don't need scale_factor and im_shape as input
while (reused_input_tensors_.size() != NumInputsOfRuntime()) {
reused_input_tensors_.pop_back();
}
if (!Infer(reused_input_tensors_, &reused_output_tensors_)) {
FDERROR << "Failed to inference by runtime." << std::endl;
return false;
}
if (!postprocessor_.Run(reused_output_tensors_, results, ims_info)){
FDERROR << "Failed to postprocess the inference results by runtime." << std::endl;
return false;
}
return true;
}
} // namespace facedet
} // namespace vision
} // namespace fastdeploy


@@ -0,0 +1,83 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/fastdeploy_model.h"
#include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/result.h"
#include "fastdeploy/vision/facedet/ppdet/blazeface/preprocessor.h"
#include "fastdeploy/vision/facedet/ppdet/blazeface/postprocessor.h"
namespace fastdeploy {
namespace vision {
namespace facedet {
/*! @brief BlazeFace model object, used to load a BlazeFace model exported by PaddleDetection.
*/
class FASTDEPLOY_DECL BlazeFace: public FastDeployModel{
public:
/** \brief Set path of model file and the configuration of runtime.
*
   * \param[in] model_file Path of model file, e.g blazeface/model.pdmodel
   * \param[in] params_file Path of parameter file, e.g blazeface/model.pdiparams, if the model format is ONNX, this parameter will be ignored
   * \param[in] config_file Path of configuration file for deployment, e.g blazeface/infer_cfg.yml
   * \param[in] custom_option RuntimeOption for inference, the default will use cpu, and choose the backend defined in "valid_cpu_backends"
   * \param[in] model_format Model format of the loaded model, default is Paddle format
*/
BlazeFace(const std::string& model_file, const std::string& params_file = "",
const std::string& config_file = "",
const RuntimeOption& custom_option = RuntimeOption(),
const ModelFormat& model_format = ModelFormat::PADDLE);
std::string ModelName() {return "blaze-face";}
/** \brief Predict the detection result for an input image
*
   * \param[in] im The input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format
   * \param[in] result The output detection result will be written to this structure
   * \return true if the prediction succeeded, otherwise false
*/
bool Predict(const cv::Mat& im, FaceDetectionResult* result);
/** \brief Predict the detection results for a batch of input images
*
   * \param[in] images The input image list, each element comes from cv::imread()
   * \param[in] results The output detection result list
   * \return true if the prediction succeeded, otherwise false
*/
virtual bool BatchPredict(const std::vector<cv::Mat>& images,
std::vector<FaceDetectionResult>* results);
/// Get preprocessor reference of BlazeFace
virtual BlazeFacePreprocessor& GetPreprocessor() {
return preprocessor_;
}
/// Get postprocessor reference of BlazeFace
virtual BlazeFacePostprocessor& GetPostprocessor() {
return postprocessor_;
}
protected:
bool Initialize();
BlazeFacePreprocessor preprocessor_;
BlazeFacePostprocessor postprocessor_;
};
} // namespace facedet
} // namespace vision
} // namespace fastdeploy


@@ -0,0 +1,84 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/pybind/main.h"
namespace fastdeploy {
void BindBlazeFace(pybind11::module& m) {
pybind11::class_<vision::facedet::BlazeFacePreprocessor>(
m, "BlazeFacePreprocessor")
.def(pybind11::init<>())
.def("run", [](vision::facedet::BlazeFacePreprocessor& self, std::vector<pybind11::array>& im_list) {
std::vector<vision::FDMat> images;
for (size_t i = 0; i < im_list.size(); ++i) {
images.push_back(vision::WrapMat(PyArrayToCvMat(im_list[i])));
}
std::vector<FDTensor> outputs;
std::vector<std::map<std::string, std::array<float, 2>>> ims_info;
if (!self.Run(&images, &outputs, &ims_info)) {
throw std::runtime_error("Failed to preprocess the input data in BlazeFacePreprocessor.");
}
for (size_t i = 0; i < outputs.size(); ++i) {
outputs[i].StopSharing();
}
return make_pair(outputs, ims_info);
});
pybind11::class_<vision::facedet::BlazeFacePostprocessor>(
m, "BlazeFacePostprocessor")
.def(pybind11::init<>())
.def("run", [](vision::facedet::BlazeFacePostprocessor& self, std::vector<FDTensor>& inputs,
const std::vector<std::map<std::string, std::array<float, 2>>>& ims_info) {
std::vector<vision::FaceDetectionResult> results;
if (!self.Run(inputs, &results, ims_info)) {
throw std::runtime_error("Failed to postprocess the runtime result in BlazeFacePostprocessor.");
}
return results;
})
.def("run", [](vision::facedet::BlazeFacePostprocessor& self, std::vector<pybind11::array>& input_array,
const std::vector<std::map<std::string, std::array<float, 2>>>& ims_info) {
std::vector<vision::FaceDetectionResult> results;
std::vector<FDTensor> inputs;
PyArrayToTensorList(input_array, &inputs, /*share_buffer=*/true);
if (!self.Run(inputs, &results, ims_info)) {
throw std::runtime_error("Failed to postprocess the runtime result in BlazePostprocessor.");
}
return results;
})
.def_property("conf_threshold", &vision::facedet::BlazeFacePostprocessor::GetConfThreshold, &vision::facedet::BlazeFacePostprocessor::SetConfThreshold)
.def_property("nms_threshold", &vision::facedet::BlazeFacePostprocessor::GetNMSThreshold, &vision::facedet::BlazeFacePostprocessor::SetNMSThreshold);
pybind11::class_<vision::facedet::BlazeFace, FastDeployModel>(m, "BlazeFace")
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
ModelFormat>())
.def("predict",
[](vision::facedet::BlazeFace& self, pybind11::array& data) {
auto mat = PyArrayToCvMat(data);
vision::FaceDetectionResult res;
self.Predict(mat, &res);
return res;
})
.def("batch_predict", [](vision::facedet::BlazeFace& self, std::vector<pybind11::array>& data) {
std::vector<cv::Mat> images;
for (size_t i = 0; i < data.size(); ++i) {
images.push_back(PyArrayToCvMat(data[i]));
}
std::vector<vision::FaceDetectionResult> results;
self.BatchPredict(images, &results);
return results;
})
.def_property_readonly("preprocessor", &vision::facedet::BlazeFace::GetPreprocessor)
.def_property_readonly("postprocessor", &vision::facedet::BlazeFace::GetPostprocessor);
}
} // namespace fastdeploy


@@ -0,0 +1,96 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/facedet/ppdet/blazeface/postprocessor.h"
#include "fastdeploy/vision/utils/utils.h"
#include "fastdeploy/vision/detection/ppdet/multiclass_nms.h"
namespace fastdeploy {
namespace vision {
namespace facedet {
BlazeFacePostprocessor::BlazeFacePostprocessor() {
conf_threshold_ = 0.5;
nms_threshold_ = 0.3;
}
bool BlazeFacePostprocessor::Run(const std::vector<FDTensor>& tensors,
std::vector<FaceDetectionResult>* results,
const std::vector<std::map<std::string,
std::array<float, 2>>>& ims_info) {
// Get number of boxes for each input image
std::vector<int> num_boxes(tensors[1].shape[0]);
int total_num_boxes = 0;
if (tensors[1].dtype == FDDataType::INT32) {
const auto* data = static_cast<const int32_t*>(tensors[1].CpuData());
for (size_t i = 0; i < tensors[1].shape[0]; ++i) {
num_boxes[i] = static_cast<int>(data[i]);
total_num_boxes += num_boxes[i];
}
  } else if (tensors[1].dtype == FDDataType::INT64) {
    const auto* data = static_cast<const int64_t*>(tensors[1].CpuData());
    for (size_t i = 0; i < tensors[1].shape[0]; ++i) {
      num_boxes[i] = static_cast<int>(data[i]);
      total_num_boxes += num_boxes[i];
    }
  }
  // Special case for TensorRT, which has a fixed output shape for NMS,
  // so there may be invalid boxes in its output
int num_output_boxes = static_cast<int>(tensors[0].Shape()[0]);
bool contain_invalid_boxes = false;
if (total_num_boxes != num_output_boxes) {
if (num_output_boxes % num_boxes.size() == 0) {
contain_invalid_boxes = true;
} else {
FDERROR << "Cannot handle the output data for this model, unexpected "
"situation."
<< std::endl;
return false;
}
}
// Get boxes for each input image
results->resize(num_boxes.size());
if (tensors[0].shape[0] == 0) {
// No detected boxes
return true;
}
const auto* box_data = static_cast<const float*>(tensors[0].CpuData());
int offset = 0;
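  // Each output box is packed as 6 floats:
  // [class_id, score, xmin, ymin, xmax, ymax]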
for (size_t i = 0; i < num_boxes.size(); ++i) {
const float* ptr = box_data + offset;
(*results)[i].Reserve(num_boxes[i]);
for (size_t j = 0; j < num_boxes[i]; ++j) {
if (ptr[j * 6 + 1] > conf_threshold_) {
(*results)[i].scores.push_back(ptr[j * 6 + 1]);
(*results)[i].boxes.emplace_back(std::array<float, 4>(
{ptr[j * 6 + 2], ptr[j * 6 + 3], ptr[j * 6 + 4], ptr[j * 6 + 5]}));
}
}
if (contain_invalid_boxes) {
offset += static_cast<int>(num_output_boxes * 6 / num_boxes.size());
} else {
offset += static_cast<int>(num_boxes[i] * 6);
}
}
return true;
}
}  // namespace facedet
} // namespace vision
} // namespace fastdeploy


@@ -0,0 +1,66 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/result.h"
namespace fastdeploy {
namespace vision {
namespace facedet {
/*! @brief Postprocessor object for the BlazeFace series model.
 */
class FASTDEPLOY_DECL BlazeFacePostprocessor {
 public:
  BlazeFacePostprocessor();
/** \brief Process the result of runtime and fill to FaceDetectionResult structure
*
* \param[in] infer_result The inference result from runtime
* \param[in] results The output result of detection
* \param[in] ims_info The shape info list, record input_shape and output_shape
   * \return true if the postprocess succeeded, otherwise false
*/
bool Run(const std::vector<FDTensor>& infer_result,
std::vector<FaceDetectionResult>* results,
const std::vector<std::map<std::string,
std::array<float, 2>>>& ims_info);
/// Set conf_threshold, default 0.5
void SetConfThreshold(const float& conf_threshold) {
conf_threshold_ = conf_threshold;
}
/// Get conf_threshold, default 0.5
float GetConfThreshold() const { return conf_threshold_; }
/// Set nms_threshold, default 0.3
void SetNMSThreshold(const float& nms_threshold) {
nms_threshold_ = nms_threshold;
}
/// Get nms_threshold, default 0.3
float GetNMSThreshold() const { return nms_threshold_; }
protected:
float conf_threshold_;
float nms_threshold_;
};
} // namespace facedet
} // namespace vision
} // namespace fastdeploy


@@ -0,0 +1,207 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/facedet/ppdet/blazeface/preprocessor.h"
#include "fastdeploy/function/concat.h"
#include "fastdeploy/function/pad.h"
#include "fastdeploy/vision/common/processors/mat.h"
#include "yaml-cpp/yaml.h"
namespace fastdeploy {
namespace vision {
namespace facedet {
BlazeFacePreprocessor::BlazeFacePreprocessor(const std::string& config_file) {
is_scale_ = false;
normalize_mean_ = {123, 117, 104};
normalize_std_ = {127.502231, 127.502231, 127.502231};
this->config_file_ = config_file;
  FDASSERT(BuildPreprocessPipelineFromConfig(),
           "Failed to create BlazeFacePreprocessor.");
}
bool BlazeFacePreprocessor::Run(std::vector<FDMat>* images, std::vector<FDTensor>* outputs,
std::vector<std::map<std::string, std::array<float, 2>>>* ims_info) {
if (images->size() == 0) {
FDERROR << "The size of input images should be greater than 0." << std::endl;
return false;
}
ims_info->resize(images->size());
outputs->resize(3);
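  // outputs[0]: the batched image tensor (filled at the end)
  // outputs[1]: scale_factor, (h_scale, w_scale) for each image
  // outputs[2]: im_shape, (height, width) of each image after preprocessing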
int batch = static_cast<int>(images->size());
// Allocate memory for scale_factor
(*outputs)[1].Resize({batch, 2}, FDDataType::FP32);
// Allocate memory for im_shape
(*outputs)[2].Resize({batch, 2}, FDDataType::FP32);
std::vector<int> max_hw({-1, -1});
auto* scale_factor_ptr =
reinterpret_cast<float*>((*outputs)[1].MutableData());
auto* im_shape_ptr = reinterpret_cast<float*>((*outputs)[2].MutableData());
// Concat all the preprocessed data to a batch tensor
std::vector<FDTensor> im_tensors(images->size());
for (size_t i = 0; i < images->size(); ++i) {
int origin_w = (*images)[i].Width();
int origin_h = (*images)[i].Height();
scale_factor_ptr[2 * i] = 1.0;
scale_factor_ptr[2 * i + 1] = 1.0;
for (size_t j = 0; j < processors_.size(); ++j) {
      if (!(*(processors_[j].get()))(&((*images)[i]))) {
        FDERROR << "Failed to process image:" << i << " in "
                << processors_[j]->Name() << "." << std::endl;
return false;
}
if (processors_[j]->Name().find("Resize") != std::string::npos) {
scale_factor_ptr[2 * i] = (*images)[i].Height() * 1.0 / origin_h;
scale_factor_ptr[2 * i + 1] = (*images)[i].Width() * 1.0 / origin_w;
}
}
if ((*images)[i].Height() > max_hw[0]) {
max_hw[0] = (*images)[i].Height();
}
if ((*images)[i].Width() > max_hw[1]) {
max_hw[1] = (*images)[i].Width();
}
im_shape_ptr[2 * i] = max_hw[0];
im_shape_ptr[2 * i + 1] = max_hw[1];
if ((*images)[i].Height() < max_hw[0] || (*images)[i].Width() < max_hw[1]) {
// if the size of image less than max_hw, pad to max_hw
FDTensor tensor;
(*images)[i].ShareWithTensor(&tensor);
function::Pad(tensor, &(im_tensors[i]),
{0, 0, max_hw[0] - (*images)[i].Height(),
max_hw[1] - (*images)[i].Width()},
0);
} else {
// No need pad
(*images)[i].ShareWithTensor(&(im_tensors[i]));
}
// Reshape to 1xCxHxW
im_tensors[i].ExpandDim(0);
}
if (im_tensors.size() == 1) {
// If there's only 1 input, no need to concat
// skip memory copy
(*outputs)[0] = std::move(im_tensors[0]);
} else {
// Else concat the im tensor for each input image
// compose a batched input tensor
function::Concat(im_tensors, &((*outputs)[0]), 0);
}
return true;
}
bool BlazeFacePreprocessor::BuildPreprocessPipelineFromConfig() {
processors_.clear();
YAML::Node cfg;
try {
cfg = YAML::LoadFile(config_file_);
} catch (YAML::BadFile& e) {
FDERROR << "Failed to load yaml file " << config_file_
<< ", maybe you should check this file." << std::endl;
return false;
}
processors_.push_back(std::make_shared<BGR2RGB>());
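  // cv::imread loads BGR images; convert to RGB first, since the exported
  // detection model expects RGB input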
bool has_permute = false;
for (const auto& op : cfg["Preprocess"]) {
std::string op_name = op["type"].as<std::string>();
if (op_name == "NormalizeImage") {
auto mean = op["mean"].as<std::vector<float>>();
auto std = op["std"].as<std::vector<float>>();
bool is_scale = true;
if (op["is_scale"]) {
is_scale = op["is_scale"].as<bool>();
}
std::string norm_type = "mean_std";
if (op["norm_type"]) {
norm_type = op["norm_type"].as<std::string>();
}
if (norm_type != "mean_std") {
std::fill(mean.begin(), mean.end(), 0.0);
std::fill(std.begin(), std.end(), 1.0);
}
processors_.push_back(std::make_shared<Normalize>(mean, std, is_scale));
} else if (op_name == "Resize") {
bool keep_ratio = op["keep_ratio"].as<bool>();
auto target_size = op["target_size"].as<std::vector<int>>();
int interp = op["interp"].as<int>();
FDASSERT(target_size.size() == 2,
"Require size of target_size be 2, but now it's %lu.",
target_size.size());
if (!keep_ratio) {
int width = target_size[1];
int height = target_size[0];
processors_.push_back(
std::make_shared<Resize>(width, height, -1.0, -1.0, interp, false));
} else {
int min_target_size = std::min(target_size[0], target_size[1]);
int max_target_size = std::max(target_size[0], target_size[1]);
std::vector<int> max_size;
if (max_target_size > 0) {
max_size.push_back(max_target_size);
max_size.push_back(max_target_size);
}
processors_.push_back(std::make_shared<ResizeByShort>(
min_target_size, interp, true, max_size));
}
} else if (op_name == "Permute") {
// Do nothing, do permute as the last operation
has_permute = true;
continue;
} else if (op_name == "Pad") {
auto size = op["size"].as<std::vector<int>>();
auto value = op["fill_value"].as<std::vector<float>>();
processors_.push_back(std::make_shared<Cast>("float"));
processors_.push_back(
std::make_shared<PadToSize>(size[1], size[0], value));
} else if (op_name == "PadStride") {
auto stride = op["stride"].as<int>();
processors_.push_back(
std::make_shared<StridePad>(stride, std::vector<float>(3, 0)));
} else {
FDERROR << "Unexcepted preprocess operator: " << op_name << "."
<< std::endl;
return false;
}
}
if (has_permute) {
// permute = cast<float> + HWC2CHW
processors_.push_back(std::make_shared<Cast>("float"));
processors_.push_back(std::make_shared<HWC2CHW>());
}
// Fusion will improve performance
FuseTransforms(&processors_);
return true;
}
} // namespace facedet
} // namespace vision
}  // namespace fastdeploy


@@ -0,0 +1,69 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/result.h"
#include "fastdeploy/vision/detection/ppdet/preprocessor.h"
namespace fastdeploy {
namespace vision {
namespace facedet {
class FASTDEPLOY_DECL BlazeFacePreprocessor:
public fastdeploy::vision::detection::PaddleDetPreprocessor {
public:
  /** \brief Create a preprocessor instance for the BlazeFace series model
*/
BlazeFacePreprocessor() = default;
  /** \brief Create a preprocessor instance for the BlazeFace series model
*
* \param[in] config_file Path of configuration file for deployment, e.g ppyoloe/infer_cfg.yml
*/
explicit BlazeFacePreprocessor(const std::string& config_file);
/** \brief Process the input image and prepare input tensors for runtime
*
* \param[in] images The input image data list, all the elements are returned by cv::imread()
* \param[in] outputs The output tensors which will feed in runtime
* \param[in] ims_info The shape info list, record input_shape and output_shape
   * \return true if the preprocess succeeded, otherwise false
*/
bool Run(std::vector<FDMat>* images, std::vector<FDTensor>* outputs,
std::vector<std::map<std::string, std::array<float, 2>>>* ims_info);
private:
bool BuildPreprocessPipelineFromConfig();
  // If is_scale_ is false, the input image can only be zoomed out;
  // the maximum resize scale cannot exceed 1.0
bool is_scale_;
std::vector<float> normalize_mean_;
std::vector<float> normalize_std_;
std::vector<std::shared_ptr<Processor>> processors_;
// read config file
std::string config_file_;
};
} // namespace facedet
} // namespace vision
} // namespace fastdeploy


@@ -15,6 +15,7 @@
from __future__ import absolute_import
from .contrib.yolov5face import YOLOv5Face
from .contrib.yolov7face import *
from .contrib.blazeface import *
from .contrib.retinaface import RetinaFace
from .contrib.scrfd import SCRFD
from .contrib.ultraface import UltraFace


@@ -0,0 +1,143 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import logging
from .... import FastDeployModel, ModelFormat
from .... import c_lib_wrap as C
class BlazeFacePreprocessor:
def __init__(self):
"""Create a preprocessor for BlazeFace
"""
self._preprocessor = C.vision.facedet.BlazeFacePreprocessor()
def run(self, input_ims):
"""Preprocess input images for BlazeFace
        :param: input_ims: (list of numpy.ndarray)The input images
        :return: tuple of (list of FDTensor, list of dict), the preprocessed tensors and the recorded image shape info
"""
return self._preprocessor.run(input_ims)
@property
def is_scale_(self):
"""
        is_scale_ for preprocessing; if false, the input image can only be zoomed out and the maximum resize scale cannot exceed 1.0
"""
return self._preprocessor.is_scale_
@is_scale_.setter
def is_scale_(self, value):
assert isinstance(
value,
bool), "The value to set `is_scale_` must be type of bool."
self._preprocessor.is_scale_ = value
class BlazeFacePostprocessor:
def __init__(self):
"""Create a postprocessor for BlazeFace
"""
self._postprocessor = C.vision.facedet.BlazeFacePostprocessor()
def run(self, runtime_results, ims_info):
"""Postprocess the runtime results for BlazeFace
:param: runtime_results: (list of FDTensor)The output FDTensor results from runtime
:param: ims_info: (list of dict)Record input_shape and output_shape
        :return: list of FaceDetectionResult(If the runtime_results is predicted by batched samples, the length of this list equals to the batch size)
"""
return self._postprocessor.run(runtime_results, ims_info)
@property
def conf_threshold(self):
"""
confidence threshold for postprocessing, default is 0.5
"""
return self._postprocessor.conf_threshold
@property
def nms_threshold(self):
"""
nms threshold for postprocessing, default is 0.3
"""
return self._postprocessor.nms_threshold
@conf_threshold.setter
def conf_threshold(self, conf_threshold):
assert isinstance(conf_threshold, float),\
"The value to set `conf_threshold` must be type of float."
self._postprocessor.conf_threshold = conf_threshold
@nms_threshold.setter
def nms_threshold(self, nms_threshold):
assert isinstance(nms_threshold, float),\
"The value to set `nms_threshold` must be type of float."
self._postprocessor.nms_threshold = nms_threshold
class BlazeFace(FastDeployModel):
def __init__(self,
model_file,
params_file="",
config_file="",
runtime_option=None,
model_format=ModelFormat.PADDLE):
"""Load a BlazeFace model exported by BlazeFace.
:param model_file: (str)Path of model file, e.g ./Blazeface.onnx
:param params_file: (str)Path of parameters file, e.g yolox/model.pdiparams, if the model_fomat is ModelFormat.ONNX, this param will be ignored, can be set as empty string
:param runtime_option: (fastdeploy.RuntimeOption)RuntimeOption for inference this model, if it's None, will use the default backend on CPU
:param model_format: (fastdeploy.ModelForamt)Model format of the loaded model
"""
super(BlazeFace, self).__init__(runtime_option)
self._model = C.vision.facedet.BlazeFace(
model_file, params_file, config_file, self._runtime_option, model_format)
assert self.initialized, "BlazeFace initialize failed."
def predict(self, input_image):
"""Detect the location and key points of human faces from an input image
:param input_image: (numpy.ndarray)The input image data, 3-D array with layout HWC, BGR format
:return: FaceDetectionResult
"""
return self._model.predict(input_image)
def batch_predict(self, images):
"""Classify a batch of input image
:param im: (list of numpy.ndarray) The input image list, each element is a 3-D array with layout HWC, BGR format
:return list of FaceDetectionResult
"""
return self._model.batch_predict(images)
@property
def preprocessor(self):
"""Get BlazefacePreprocessor object of the loaded model
:return BlazefacePreprocessor
"""
return self._model.preprocessor
@property
def postprocessor(self):
"""Get BlazefacePostprocessor object of the loaded model
:return BlazefacePostprocessor
"""
return self._model.postprocessor


@@ -0,0 +1,151 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from fastdeploy import ModelFormat
import fastdeploy as fd
import cv2
import os
import pickle
import numpy as np
import runtime_config as rc
def test_detection_blazeface():
model_url = "https://bj.bcebos.com/paddlehub/fastdeploy/blazeface_1000e.tgz"
input_url1 = "https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg"
input_url2 = "https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000570688.jpg"
result_url1 = "https://bj.bcebos.com/paddlehub/fastdeploy/blazeface_result1.pkl"
result_url2 = "https://bj.bcebos.com/paddlehub/fastdeploy/blazeface_result2.pkl"
    fd.download_and_decompress(model_url, "resources")
    fd.download(input_url1, "resources")
    fd.download(input_url2, "resources")
    fd.download(result_url1, "resources")
    fd.download(result_url2, "resources")
model_dir = "resources/blazeface_1000e"
model_file = os.path.join(model_dir, "model.pdmodel")
params_file = os.path.join(model_dir, "model.pdiparams")
config_file = os.path.join(model_dir, "infer_cfg.yml")
model = fd.vision.facedet.BlazeFace(
model_file, params_file, config_file, runtime_option=rc.test_option)
model.postprocessor.conf_threshold = 0.5
with open("resources/blazeface_result1.pkl", "rb") as f:
expect1 = pickle.load(f)
with open("resources/blazeface_result2.pkl", "rb") as f:
expect2 = pickle.load(f)
im1 = cv2.imread("./resources/000000014439.jpg")
im2 = cv2.imread("./resources/000000570688.jpg")
for i in range(3):
# test single predict
result1 = model.predict(im1)
result2 = model.predict(im2)
diff_boxes_1 = np.fabs(
np.array(result1.boxes) - np.array(expect1["boxes"]))
diff_boxes_2 = np.fabs(
np.array(result2.boxes) - np.array(expect2["boxes"]))
diff_scores_1 = np.fabs(
np.array(result1.scores) - np.array(expect1["scores"]))
diff_scores_2 = np.fabs(
np.array(result2.scores) - np.array(expect2["scores"]))
assert diff_boxes_1.max(
) < 1e-04, "There's difference in detection boxes 1."
assert diff_scores_1.max(
) < 1e-04, "There's difference in detection score 1."
assert diff_boxes_2.max(
) < 1e-03, "There's difference in detection boxes 2."
assert diff_scores_2.max(
) < 1e-04, "There's difference in detection score 2."
print("one image test success!")
# test batch predict
results = model.batch_predict([im1, im2])
result1 = results[0]
result2 = results[1]
diff_boxes_1 = np.fabs(
np.array(result1.boxes) - np.array(expect1["boxes"]))
diff_boxes_2 = np.fabs(
np.array(result2.boxes) - np.array(expect2["boxes"]))
diff_scores_1 = np.fabs(
np.array(result1.scores) - np.array(expect1["scores"]))
diff_scores_2 = np.fabs(
np.array(result2.scores) - np.array(expect2["scores"]))
assert diff_boxes_1.max(
) < 1e-04, "There's difference in detection boxes 1."
assert diff_scores_1.max(
) < 1e-03, "There's difference in detection score 1."
assert diff_boxes_2.max(
) < 1e-04, "There's difference in detection boxes 2."
assert diff_scores_2.max(
) < 1e-04, "There's difference in detection score 2."
print("batch predict success!")
def test_detection_blazeface_runtime():
model_url = "https://bj.bcebos.com/paddlehub/fastdeploy/blazeface_1000e.tgz"
input_url1 = "https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg"
result_url1 = "https://bj.bcebos.com/paddlehub/fastdeploy/blazeface_result1.pkl"
fd.download_and_decompress(model_url, "resources")
fd.download(input_url1, "resources")
fd.download(result_url1, "resources")
model_dir = "resources/blazeface_1000e"
model_file = os.path.join(model_dir, "model.pdmodel")
params_file = os.path.join(model_dir, "model.pdiparams")
config_file = os.path.join(model_dir, "infer_cfg.yml")
preprocessor = fd.vision.facedet.BlazeFacePreprocessor()
postprocessor = fd.vision.facedet.BlazeFacePostprocessor()
rc.test_option.set_model_path(model_file, params_file, config_file, model_format=ModelFormat.PADDLE)
rc.test_option.use_openvino_backend()
runtime = fd.Runtime(rc.test_option)
with open("resources/blazeface_result1.pkl", "rb") as f:
expect1 = pickle.load(f)
im1 = cv2.imread("resources/000000014439.jpg")
for i in range(3):
# test runtime
input_tensors, ims_info = preprocessor.run([im1.copy()])
output_tensors = runtime.infer({"images": input_tensors[0]})
results = postprocessor.run(output_tensors, ims_info)
result1 = results[0]
diff_boxes_1 = np.fabs(
np.array(result1.boxes) - np.array(expect1["boxes"]))
diff_scores_1 = np.fabs(
np.array(result1.scores) - np.array(expect1["scores"]))
assert diff_boxes_1.max(
) < 1e-03, "There's difference in detection boxes 1."
assert diff_scores_1.max(
) < 1e-04, "There's difference in detection score 1."
if __name__ == "__main__":
test_detection_blazeface()
    test_detection_blazeface_runtime()