mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 00:57:33 +08:00
[Model] Support BlazeFace Model (#1172)
* fit yolov7face file path * TODO:添加yolov7facePython接口Predict * resolve yolov7face.py * resolve yolov7face.py * resolve yolov7face.py * add yolov7face example readme file * [Doc] fix yolov7face example readme file * [Doc]fix yolov7face example readme file * support BlazeFace * add blazeface readme file * fix review problem * fix code style error * fix review problem * fix review problem * fix head file problem * fix review problem * fix review problem * fix readme file problem * add English readme file * fix English readme file
This commit is contained in:
34
examples/vision/facedet/blazeface/README.md
Normal file
34
examples/vision/facedet/blazeface/README.md
Normal file
@@ -0,0 +1,34 @@
|
||||
English | [简体中文](README_CN.md)
|
||||
# BlazeFace Ready-to-deploy Model
|
||||
|
||||
- BlazeFace deployment model implementation comes from [BlazeFace](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.5/configs/face_detection),and [Pre-training model based on WiderFace](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.5/configs/face_detection)
|
||||
- (1)Provided in [Official library
|
||||
](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.5/tools) *.params, could deploy after operation [export_model.py](#Export PADDLE model);
|
||||
- (2)Developers can train BlazeFace model based on their own data according to [export_model. py](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.5/tools/export_model.py)After exporting the model, complete the deployment。
|
||||
|
||||
## Export PADDLE model
|
||||
|
||||
Visit [BlazeFace](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.5/configs/face_detection) Github library, download and install according to the instructions, download the `. yml` and `. params` model parameters, and use` export_ Model. py `gets the` pad `model file`. yml,. pdiparams,. pdmodel `.
|
||||
|
||||
|
||||
* Download BlazeFace model parameter file
|
||||
|
||||
|Network structure | input size | number of pictures/GPU | learning rate strategy | Easy/Media/Hard Set | prediction delay (SD855) | model size (MB) | download | configuration file|
|
||||
|:------------:|:--------:|:----:|:-------:|:-------:|:---------:|:----------:|:---------:|:--------:|
|
||||
| BlazeFace | 640 | 8 | 1000e | 0.885 / 0.855 / 0.731 | - | 0.472 |[Download link](https://paddledet.bj.bcebos.com/models/blazeface_1000e.pdparams) | [Config file](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.5/configs/face_detection/blazeface_1000e.yml) |
|
||||
| BlazeFace-FPN-SSH | 640 | 8 | 1000e | 0.907 / 0.883 / 0.793 | - | 0.479 |[Download link](https://paddledet.bj.bcebos.com/models/blazeface_fpn_ssh_1000e.pdparams) | [Config file](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.5/configs/face_detection/blazeface_fpn_ssh_1000e.yml) |
|
||||
|
||||
* Export paddle-format file
|
||||
```bash
|
||||
python tools/export_model.py -c configs/face_detection/blazeface_1000e.yml -o weights=blazeface_1000e.pdparams --export_serving_model=True
|
||||
```
|
||||
|
||||
## Detailed Deployment Tutorials
|
||||
|
||||
- [Python Deployment](python)
|
||||
- [C++ Deployment](cpp)
|
||||
|
||||
|
||||
## Release Note
|
||||
|
||||
- This tutorial and related code are written based on [BlazeFace](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.5/configs/face_detection)
|
31
examples/vision/facedet/blazeface/README_CN.md
Normal file
31
examples/vision/facedet/blazeface/README_CN.md
Normal file
@@ -0,0 +1,31 @@
|
||||
# BlazeFace准备部署模型
|
||||
|
||||
- BlazeFace部署模型实现来自[BlazeFace](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.5/configs/face_detection),和[基于WiderFace的预训练模型](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.5/configs/face_detection)
|
||||
- (1)[官方库](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.5/tools)中提供的*.params,通过[export_model.py](#导出PADDLE模型)操作后,可进行部署;
|
||||
- (2)开发者基于自己数据训练的BlazeFace模型,可按照[export_model.py](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.5/tools/export_model.py)导出模型后,完成部署。
|
||||
|
||||
## 导出PADDLE模型
|
||||
|
||||
访问[BlazeFace](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.5/configs/face_detection)github库,按照指引下载安装,下载`.yml`和`.params` 模型参数,利用 `export_model.py` 得到`paddle`模型文件`.yml, .pdiparams, .pdmodel`。
|
||||
|
||||
* 下载BlazeFace模型参数文件
|
||||
|
||||
| 网络结构 | 输入尺寸 | 图片个数/GPU | 学习率策略 | Easy/Medium/Hard Set | 预测时延(SD855)| 模型大小(MB) | 下载 | 配置文件 |
|
||||
|:------------:|:--------:|:----:|:-------:|:-------:|:---------:|:----------:|:---------:|:--------:|
|
||||
| BlazeFace | 640 | 8 | 1000e | 0.885 / 0.855 / 0.731 | - | 0.472 |[下载链接](https://paddledet.bj.bcebos.com/models/blazeface_1000e.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.5/configs/face_detection/blazeface_1000e.yml) |
|
||||
| BlazeFace-FPN-SSH | 640 | 8 | 1000e | 0.907 / 0.883 / 0.793 | - | 0.479 |[下载链接](https://paddledet.bj.bcebos.com/models/blazeface_fpn_ssh_1000e.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.5/configs/face_detection/blazeface_fpn_ssh_1000e.yml) |
|
||||
|
||||
* 导出paddle格式文件
|
||||
```bash
|
||||
python tools/export_model.py -c configs/face_detection/blazeface_1000e.yml -o weights=blazeface_1000e.pdparams --export_serving_model=True
|
||||
```
|
||||
|
||||
## 详细部署文档
|
||||
|
||||
- [Python部署](python)
|
||||
- [C++部署](cpp)
|
||||
|
||||
|
||||
## 版本说明
|
||||
|
||||
- 本版本文档和代码基于[BlazeFace](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.5/configs/face_detection) 编写
|
14
examples/vision/facedet/blazeface/cpp/CMakeLists.txt
Normal file
14
examples/vision/facedet/blazeface/cpp/CMakeLists.txt
Normal file
@@ -0,0 +1,14 @@
|
||||
PROJECT(infer_demo C CXX)
|
||||
CMAKE_MINIMUM_REQUIRED (VERSION 3.10)
|
||||
|
||||
# Specifies the path to the fastdeploy library after you have downloaded it
|
||||
option(FASTDEPLOY_INSTALL_DIR "Path of downloaded fastdeploy sdk.")
|
||||
|
||||
include(../../../../../FastDeploy.cmake)
|
||||
|
||||
# Add the FastDeploy dependency header
|
||||
include_directories(${FASTDEPLOY_INCS})
|
||||
|
||||
add_executable(infer_demo ${PROJECT_SOURCE_DIR}/infer.cc)
|
||||
# Add the FastDeploy library dependency
|
||||
target_link_libraries(infer_demo ${FASTDEPLOY_LIBS})
|
78
examples/vision/facedet/blazeface/cpp/README.md
Normal file
78
examples/vision/facedet/blazeface/cpp/README.md
Normal file
@@ -0,0 +1,78 @@
|
||||
English | [简体中文](README_CN.md)
|
||||
# BlazeFace C++ Deployment Example
|
||||
|
||||
This directory provides examples that `infer.cc` fast finishes the deployment of BlazeFace on CPU/GPU。
|
||||
|
||||
Before deployment, two steps require confirmation
|
||||
|
||||
- 1. Software and hardware should meet the requirements. Please refer to [FastDeploy Environment Requirements](../../../../../docs/en/build_and_install/download_prebuilt_libraries.md)
|
||||
- 2. Download the precompiled deployment library and samples code according to your development environment. Refer to [FastDeploy Precompiled Library](../../../../../docs/en/build_and_install/download_prebuilt_libraries.md)
|
||||
|
||||
Taking the CPU inference on Linux as an example, the compilation test can be completed by executing the following command in this directory.
|
||||
|
||||
```bash
|
||||
mkdir build
|
||||
cd build
|
||||
# Download the FastDeploy precompiled library. Users can choose your appropriate version in the `FastDeploy Precompiled Library` mentioned above
|
||||
wget https://bj.bcebos.com/fastdeploy/release/cpp/fastdeploy-linux-x64-x.x.x.tgz # x.x.x >= 1.0.4
|
||||
tar xvf fastdeploy-linux-x64-x.x.x.tgz # x.x.x >= 1.0.4
|
||||
cmake .. -DFASTDEPLOY_INSTALL_DIR=${PWD}/fastdeploy-linux-x64-x.x.x # x.x.x >= 1.0.4
|
||||
make -j
|
||||
|
||||
#Download the official converted YOLOv7Face model files and test images
|
||||
wget https://raw.githubusercontent.com/DefTruth/lite.ai.toolkit/main/examples/lite/resources/test_lite_face_detector_3.jpg
|
||||
wget https://bj.bcebos.com/paddlehub/fastdeploy/blzeface-1000e.tgz
|
||||
|
||||
#Use blazeface-1000e model
|
||||
# CPU inference
|
||||
./infer_demo blazeface-1000e/ test_lite_face_detector_3.jpg 0
|
||||
# GPU Inference
|
||||
./infer_demo blazeface-1000e/ test_lite_face_detector_3.jpg 1
|
||||
```
|
||||
|
||||
The visualized result after running is as follows
|
||||
|
||||
<img width="640" src="https://user-images.githubusercontent.com/49013063/206170111-843febb6-67d6-4c46-a121-d87d003bba21.jpg">
|
||||
|
||||
The above command works for Linux or MacOS. For SDK use-pattern in Windows, refer to:
|
||||
- [How to use FastDeploy C++ SDK in Windows](../../../../../docs/cn/faq/use_sdk_on_windows.md)
|
||||
|
||||
## BlazeFace C++ Interface
|
||||
|
||||
### BlazeFace Class
|
||||
|
||||
```c++
|
||||
fastdeploy::vision::facedet::BlazeFace(
|
||||
const string& model_file,
|
||||
const string& params_file = "",
|
||||
const string& config_file = "",
|
||||
const RuntimeOption& runtime_option = RuntimeOption(),
|
||||
const ModelFormat& model_format = ModelFormat::PADDLE)
|
||||
```
|
||||
|
||||
BlazeFace model loading and initialization, among which model_file is the exported PADDLE model format
|
||||
|
||||
**Parameter**
|
||||
|
||||
> * **model_file**(str): Model file path
|
||||
> * **params_file**(str): Parameter file path. Only passing an empty string when the model is in PADDLE format
|
||||
> * **config_file**(str): Config file path. Only passing an empty string when the model is in PADDLE format
|
||||
> * **runtime_option**(RuntimeOption): Backend inference configuration. None by default, which is the default configuration
|
||||
> * **model_format**(ModelFormat): Model format. PADDLE format by default
|
||||
|
||||
#### Predict Function
|
||||
|
||||
> ```c++
|
||||
> BlazeFace::Predict(cv::Mat& im, FaceDetectionResult* result)
|
||||
> ```
|
||||
>
|
||||
> Model prediction interface. Input images and output detection results.
|
||||
>
|
||||
> **Parameter**
|
||||
>
|
||||
> > * **im**: Input images in HWC or BGR format
|
||||
> > * **result**: Detection results, including detection box and confidence of each box. Refer to [Vision Model Prediction Result](../../../../../docs/api/vision_results/) for FaceDetectionResult
|
||||
|
||||
- [Model Description](../../)
|
||||
- [Python Deployment](../python)
|
||||
- [Vision Model Prediction Results](../../../../../docs/api/vision_results/)
|
77
examples/vision/facedet/blazeface/cpp/README_CN.md
Normal file
77
examples/vision/facedet/blazeface/cpp/README_CN.md
Normal file
@@ -0,0 +1,77 @@
|
||||
[English](README.md) | 简体中文
|
||||
# BlazeFace C++部署示例
|
||||
|
||||
本目录下提供`infer.cc`快速完成BlazeFace在CPU/GPU部署的示例。
|
||||
|
||||
在部署前,需确认以下两个步骤
|
||||
|
||||
- 1. 软硬件环境满足要求,参考[FastDeploy环境要求](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md)
|
||||
- 2. 根据开发环境,下载预编译部署库和samples代码,参考[FastDeploy预编译库](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md)
|
||||
|
||||
以Linux上CPU推理为例,在本目录执行如下命令即可完成编译测试
|
||||
|
||||
```bash
|
||||
mkdir build
|
||||
cd build
|
||||
# 下载FastDeploy预编译库,用户可在上文提到的`FastDeploy预编译库`中自行选择合适的版本使用
|
||||
wget https://bj.bcebos.com/fastdeploy/release/cpp/fastdeploy-linux-x64-x.x.x.tgz # x.x.x >= 1.0.4
|
||||
tar xvf fastdeploy-linux-x64-x.x.x.tgz # x.x.x >= 1.0.4
|
||||
cmake .. -DFASTDEPLOY_INSTALL_DIR=${PWD}/fastdeploy-linux-x64-x.x.x # x.x.x >= 1.0.4
|
||||
make -j
|
||||
|
||||
#下载官方转换好的BlazeFace模型文件和测试图片
|
||||
wget https://raw.githubusercontent.com/DefTruth/lite.ai.toolkit/main/examples/lite/resources/test_lite_face_detector_3.jpg
|
||||
wget https://bj.bcebos.com/paddlehub/fastdeploy/blzeface-1000e.tgz
|
||||
|
||||
#使用blazeface-1000e模型
|
||||
# CPU推理
|
||||
./infer_demo blazeface-1000e/ test_lite_face_detector_3.jpg 0
|
||||
# GPU推理
|
||||
./infer_demo blazeface-1000e/ test_lite_face_detector_3.jpg 1
|
||||
|
||||
运行完成可视化结果如下图所示
|
||||
|
||||
<img width="640" src="https://user-images.githubusercontent.com/49013063/206170111-843febb6-67d6-4c46-a121-d87d003bba21.jpg">
|
||||
|
||||
以上命令只适用于Linux或MacOS, Windows下SDK的使用方式请参考:
|
||||
- [如何在Windows中使用FastDeploy C++ SDK](../../../../../docs/cn/faq/use_sdk_on_windows.md)
|
||||
|
||||
## BlazeFace C++接口
|
||||
|
||||
### BlazeFace类
|
||||
|
||||
```c++
|
||||
fastdeploy::vision::facedet::BlazeFace(
|
||||
const string& model_file,
|
||||
const string& params_file = "",
|
||||
const string& config_file = "",
|
||||
const RuntimeOption& runtime_option = RuntimeOption(),
|
||||
const ModelFormat& model_format = ModelFormat::PADDLE)
|
||||
```
|
||||
|
||||
BlazeFace模型加载和初始化,其中model_file为导出的PADDLE模型格式。
|
||||
|
||||
**参数**
|
||||
|
||||
> * **model_file**(str): 模型文件路径
|
||||
> * **params_file**(str): 参数文件路径,当模型格式为ONNX时,此参数传入空字符串即可
|
||||
> * **config_file**(str): 配置文件路径,当模型格式为ONNX时,此参数传入空字符串即可
|
||||
> * **runtime_option**(RuntimeOption): 后端推理配置,默认为None,即采用默认配置
|
||||
> * **model_format**(ModelFormat): 模型格式,默认为PADDLE格式
|
||||
|
||||
#### Predict函数
|
||||
|
||||
> ```c++
|
||||
> BlazeFace::Predict(cv::Mat& im, FaceDetectionResult* result)
|
||||
> ```
|
||||
>
|
||||
> 模型预测接口,输入图像直接输出检测结果。
|
||||
>
|
||||
> **参数**
|
||||
>
|
||||
> > * **im**: 输入图像,注意需为HWC,BGR格式
|
||||
> > * **result**: 检测结果,包括检测框,各个框的置信度, FaceDetectionResult说明参考[视觉模型预测结果](../../../../../docs/api/vision_results/)
|
||||
|
||||
- [模型介绍](../../)
|
||||
- [Python部署](../python)
|
||||
- [视觉模型预测结果](../../../../../docs/api/vision_results/)
|
94
examples/vision/facedet/blazeface/cpp/infer.cc
Normal file
94
examples/vision/facedet/blazeface/cpp/infer.cc
Normal file
@@ -0,0 +1,94 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "fastdeploy/vision.h"
|
||||
|
||||
#ifdef WIN32
|
||||
const char sep = '\\';
|
||||
#else
|
||||
const char sep = '/';
|
||||
#endif
|
||||
|
||||
void CpuInfer(const std::string& model_dir, const std::string& image_file) {
|
||||
auto model_file = model_dir + sep + "model.pdmodel";
|
||||
auto params_file = model_dir + sep + "model.pdiparams";
|
||||
auto config_file = model_dir + sep + "infer_cfg.yml";
|
||||
auto option = fastdeploy::RuntimeOption();
|
||||
option.UseCpu();
|
||||
auto model = fastdeploy::vision::facedet::BlazeFace(
|
||||
model_file, params_file, config_file, option);
|
||||
if (!model.Initialized()) {
|
||||
std::cerr << "Failed to initialize." << std::endl;
|
||||
return;
|
||||
}
|
||||
|
||||
auto im = cv::imread(image_file);
|
||||
|
||||
fastdeploy::vision::FaceDetectionResult res;
|
||||
if (!model.Predict(im, &res)) {
|
||||
std::cerr << "Failed to predict." << std::endl;
|
||||
return;
|
||||
}
|
||||
std::cout << res.Str() << std::endl;
|
||||
|
||||
auto vis_im = fastdeploy::vision::VisFaceDetection(im, res);
|
||||
cv::imwrite("vis_result.jpg", vis_im);
|
||||
std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
|
||||
}
|
||||
|
||||
void GpuInfer(const std::string& model_dir, const std::string& image_file) {
|
||||
auto model_file = model_dir + sep + "model.pdmodel";
|
||||
auto params_file = model_dir + sep + "model.pdiparams";
|
||||
auto config_file = model_dir + sep + "infer_cfg.yml";
|
||||
auto option = fastdeploy::RuntimeOption();
|
||||
option.UseGpu();
|
||||
auto model = fastdeploy::vision::facedet::BlazeFace(
|
||||
model_file, params_file, config_file, option);
|
||||
if (!model.Initialized()) {
|
||||
std::cerr << "Failed to initialize." << std::endl;
|
||||
return;
|
||||
}
|
||||
|
||||
auto im = cv::imread(image_file);
|
||||
|
||||
fastdeploy::vision::FaceDetectionResult res;
|
||||
if (!model.Predict(im, &res)) {
|
||||
std::cerr << "Failed to predict." << std::endl;
|
||||
return;
|
||||
}
|
||||
std::cout << res.Str() << std::endl;
|
||||
|
||||
auto vis_im = fastdeploy::vision::VisFaceDetection(im, res);
|
||||
cv::imwrite("vis_result.jpg", vis_im);
|
||||
std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
if (argc < 4) {
|
||||
std::cout << "Usage: infer_demo path/to/model path/to/image run_option, "
|
||||
"e.g ./infer_model yolov5s-face.onnx ./test.jpeg 0"
|
||||
<< std::endl;
|
||||
std::cout << "The data type of run_option is int, 0: run with cpu; 1: run "
|
||||
"with gpu; 2: run with gpu and use tensorrt backend."
|
||||
<< std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (std::atoi(argv[3]) == 0) {
|
||||
CpuInfer(argv[1], argv[2]);
|
||||
} else if (std::atoi(argv[3]) == 1) {
|
||||
GpuInfer(argv[1], argv[2]);
|
||||
}
|
||||
return 0;
|
||||
}
|
68
examples/vision/facedet/blazeface/python/README.md
Normal file
68
examples/vision/facedet/blazeface/python/README.md
Normal file
@@ -0,0 +1,68 @@
|
||||
English | [简体中文](README_CN.md)
|
||||
# BlazeFace Python Deployment Example
|
||||
|
||||
Before deployment, two steps require confirmation
|
||||
|
||||
- 1. Software and hardware should meet the requirements. Please refer to [FastDeploy Environment Requirements](../../../../../docs/en/build_and_install/download_prebuilt_libraries.md)
|
||||
- 2. Install FastDeploy Python whl package. Refer to [FastDeploy Python Installation](../../../../../docs/en/build_and_install/download_prebuilt_libraries.md)
|
||||
|
||||
This directory provides examples that `infer.py` fast finishes the deployment of BlazeFace on CPU/GPU.
|
||||
|
||||
```bash
|
||||
# Download the example code for deployment
|
||||
git clone https://github.com/PaddlePaddle/FastDeploy.git
|
||||
cd examples/vision/facedet/blazeface/python/
|
||||
|
||||
# Download BlazeFace model files and test images
|
||||
wget https://raw.githubusercontent.com/DefTruth/lite.ai.toolkit/main/examples/lite/resources/test_lite_face_detector_3.jpg
|
||||
wget https://bj.bcebos.com/paddlehub/fastdeploy/blazeface-1000e.tgz
|
||||
|
||||
# Use blazeface-1000e model
|
||||
# CPU Inference
|
||||
python infer.py --model blazeface-1000e/ --image test_lite_face_detector_3.jpg --device cpu
|
||||
# GPU Inference
|
||||
python infer.py --model blazeface-1000e/ --image test_lite_face_detector_3.jpg --device gpu
|
||||
```
|
||||
|
||||
The visualized result after running is as follows
|
||||
|
||||
<img width="640" src="https://user-images.githubusercontent.com/67993288/184301839-a29aefae-16c9-4196-bf9d-9c6cf694f02d.jpg">
|
||||
|
||||
## BlazeFace Python Interface
|
||||
|
||||
```python
|
||||
fastdeploy.vision.facedet.BlzaeFace(model_file, params_file=None, runtime_option=None, config_file=None, model_format=ModelFormat.PADDLE)
|
||||
```
|
||||
|
||||
BlazeFace model loading and initialization, among which model_file is the exported PADDLE model format
|
||||
|
||||
**Parameter**
|
||||
|
||||
> * **model_file**(str): Model file path
|
||||
> * **params_file**(str): Parameter file path. No need to set when the model is in PADDLE format
|
||||
> * **config_file**(str): config file path. No need to set when the model is in PADDLE format
|
||||
> * **runtime_option**(RuntimeOption): Backend inference configuration. None by default, which is the default configuration
|
||||
> * **model_format**(ModelFormat): Model format. PADDLE format by default
|
||||
|
||||
### predict function
|
||||
|
||||
> ```python
|
||||
> BlazeFace.predict(input_image)
|
||||
> ```
|
||||
> Through let BlazeFace.postprocessor.conf_threshold = 0.2,to modify conf_threshold
|
||||
>
|
||||
> Model prediction interface. Input images and output detection results.
|
||||
>
|
||||
> **Parameter**
|
||||
>
|
||||
> > * **input_image**(np.ndarray): Input image in HWC or BGR format
|
||||
|
||||
> **Return**
|
||||
>
|
||||
> > Return`fastdeploy.vision.FaceDetectionResult` structure. Refer to [Vision Model Prediction Results](../../../../../docs/api/vision_results/) for its description.
|
||||
|
||||
## Other Documents
|
||||
|
||||
- [BlazeFace Model Description](..)
|
||||
- [BlazeFace C++ Deployment](../cpp)
|
||||
- [Model Prediction Results](../../../../../docs/api/vision_results/)
|
68
examples/vision/facedet/blazeface/python/README_CN.md
Normal file
68
examples/vision/facedet/blazeface/python/README_CN.md
Normal file
@@ -0,0 +1,68 @@
|
||||
[English](README.md) | 简体中文
|
||||
# BlazeFace Python部署示例
|
||||
|
||||
在部署前,需确认以下两个步骤
|
||||
|
||||
- 1. 软硬件环境满足要求,参考[FastDeploy环境要求](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md)
|
||||
- 2. FastDeploy Python whl包安装,参考[FastDeploy Python安装](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md)
|
||||
|
||||
本目录下提供`infer.py`快速完成BlazeFace在CPU/GPU部署的示例。执行如下脚本即可完成
|
||||
|
||||
```bash
|
||||
#下载部署示例代码
|
||||
git clone https://github.com/PaddlePaddle/FastDeploy.git
|
||||
cd examples/vision/facedet/blazeface/python/
|
||||
|
||||
#下载BlazeFace模型文件和测试图片
|
||||
wget https://raw.githubusercontent.com/DefTruth/lite.ai.toolkit/main/examples/lite/resources/test_lite_face_detector_3.jpg
|
||||
wget https://bj.bcebos.com/paddlehub/fastdeploy/blazeface-1000e.tgz
|
||||
|
||||
#使用blazeface-1000e模型
|
||||
# CPU推理
|
||||
python infer.py --model blazeface-1000e/ --image test_lite_face_detector_3.jpg --device cpu
|
||||
# GPU推理
|
||||
python infer.py --model blazeface-1000e/ --image test_lite_face_detector_3.jpg --device gpu
|
||||
```
|
||||
|
||||
运行完成可视化结果如下图所示
|
||||
|
||||
<img width="640" src="https://user-images.githubusercontent.com/67993288/184301839-a29aefae-16c9-4196-bf9d-9c6cf694f02d.jpg">
|
||||
|
||||
## BlazeFace Python接口
|
||||
|
||||
```python
|
||||
fastdeploy.vision.facedet.BlzaeFace(model_file, params_file=None, runtime_option=None, config_file=None, model_format=ModelFormat.PADDLE)
|
||||
```
|
||||
|
||||
BlazeFace模型加载和初始化
|
||||
|
||||
**参数**
|
||||
|
||||
> * **model_file**(str): 模型文件路径
|
||||
> * **params_file**(str): 参数文件路径,当模型格式为ONNX格式时,此参数无需设定
|
||||
> * **config_file**(str): config文件路径,当模型格式为ONNX格式时,此参数无需设定
|
||||
> * **runtime_option**(RuntimeOption): 后端推理配置,默认为None,即采用默认配置
|
||||
> * **model_format**(ModelFormat): 模型格式,默认为PADDLE
|
||||
|
||||
### predict函数
|
||||
|
||||
> ```python
|
||||
> BlazeFace.predict(input_image)
|
||||
> ```
|
||||
> 通过BlazeFace.postprocessor.conf_threshold = 0.2,来修改conf_threshold
|
||||
>
|
||||
> 模型预测结口,输入图像直接输出检测结果。
|
||||
>
|
||||
> **参数**
|
||||
>
|
||||
> > * **input_image**(np.ndarray): 输入数据,注意需为HWC,BGR格式
|
||||
|
||||
> **返回**
|
||||
>
|
||||
> > 返回`fastdeploy.vision.FaceDetectionResult`结构体,结构体说明参考文档[视觉模型预测结果](../../../../../docs/api/vision_results/)
|
||||
|
||||
## 其它文档
|
||||
|
||||
- [BlazeFace 模型介绍](..)
|
||||
- [BlazeFace C++部署](../cpp)
|
||||
- [模型预测结果说明](../../../../../docs/api/vision_results/)
|
58
examples/vision/facedet/blazeface/python/infer.py
Normal file
58
examples/vision/facedet/blazeface/python/infer.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import fastdeploy as fd
|
||||
import cv2
|
||||
import os
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
import argparse
|
||||
import ast
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--model", required=True, help="Path of blazeface model dir.")
|
||||
parser.add_argument(
|
||||
"--image", required=True, help="Path of test image file.")
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default='cpu',
|
||||
help="Type of inference device, support 'cpu' or 'gpu'.")
|
||||
parser.add_argument(
|
||||
"--use_trt",
|
||||
type=ast.literal_eval,
|
||||
default=False,
|
||||
help="Wether to use tensorrt.")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def build_option(args):
|
||||
option = fd.RuntimeOption()
|
||||
|
||||
if args.device.lower() == "gpu":
|
||||
option.use_gpu()
|
||||
|
||||
if args.use_trt:
|
||||
option.use_trt_backend()
|
||||
option.set_trt_input_shape("images", [1, 3, 640, 640])
|
||||
return option
|
||||
|
||||
|
||||
args = parse_arguments()
|
||||
|
||||
model_dir = args.model
|
||||
|
||||
model_file = os.path.join(model_dir, "model.pdmodel")
|
||||
params_file = os.path.join(model_dir, "model.pdiparams")
|
||||
config_file = os.path.join(model_dir, "infer_cfg.yml")
|
||||
|
||||
# Configure runtime and load the model
|
||||
runtime_option = build_option(args)
|
||||
model = fd.vision.facedet.BlazeFace(model_file, params_file, config_file, runtime_option=runtime_option)
|
||||
|
||||
# Predict image detection results
|
||||
im = cv2.imread(args.image)
|
||||
result = model.predict(im)
|
||||
print(result)
|
||||
# Visualization of prediction Results
|
||||
vis_im = fd.vision.vis_face_detection(im, result)
|
||||
cv2.imwrite("visualized_result.jpg", vis_im)
|
||||
print("Visualized result save in ./visualized_result.jpg")
|
@@ -41,6 +41,7 @@
|
||||
#include "fastdeploy/vision/facedet/contrib/ultraface.h"
|
||||
#include "fastdeploy/vision/facedet/contrib/yolov5face.h"
|
||||
#include "fastdeploy/vision/facedet/contrib/yolov7face/yolov7face.h"
|
||||
#include "fastdeploy/vision/facedet/ppdet/blazeface/blazeface.h"
|
||||
#include "fastdeploy/vision/faceid/contrib/insightface/model.h"
|
||||
#include "fastdeploy/vision/faceid/contrib/adaface/adaface.h"
|
||||
#include "fastdeploy/vision/headpose/contrib/fsanet.h"
|
||||
|
@@ -20,6 +20,7 @@ void BindRetinaFace(pybind11::module& m);
|
||||
void BindUltraFace(pybind11::module& m);
|
||||
void BindYOLOv5Face(pybind11::module& m);
|
||||
void BindYOLOv7Face(pybind11::module& m);
|
||||
void BindBlazeFace(pybind11::module& m);
|
||||
void BindSCRFD(pybind11::module& m);
|
||||
|
||||
void BindFaceDet(pybind11::module& m) {
|
||||
@@ -28,6 +29,7 @@ void BindFaceDet(pybind11::module& m) {
|
||||
BindUltraFace(facedet_module);
|
||||
BindYOLOv5Face(facedet_module);
|
||||
BindYOLOv7Face(facedet_module);
|
||||
BindBlazeFace(facedet_module);
|
||||
BindSCRFD(facedet_module);
|
||||
}
|
||||
} // namespace fastdeploy
|
||||
|
93
fastdeploy/vision/facedet/ppdet/blazeface/blazeface.cc
Normal file
93
fastdeploy/vision/facedet/ppdet/blazeface/blazeface.cc
Normal file
@@ -0,0 +1,93 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "fastdeploy/vision/facedet/ppdet/blazeface/blazeface.h"
|
||||
#include "fastdeploy/utils/perf.h"
|
||||
#include "fastdeploy/vision/utils/utils.h"
|
||||
|
||||
namespace fastdeploy{
|
||||
|
||||
namespace vision{
|
||||
|
||||
namespace facedet{
|
||||
|
||||
BlazeFace::BlazeFace(const std::string& model_file,
|
||||
const std::string& params_file,
|
||||
const std::string& config_file,
|
||||
const RuntimeOption& custom_option,
|
||||
const ModelFormat& model_format)
|
||||
: preprocessor_(config_file){
|
||||
valid_cpu_backends = {Backend::OPENVINO, Backend::PDINFER, Backend::LITE};
|
||||
valid_gpu_backends = {Backend::OPENVINO, Backend::LITE, Backend::PDINFER};
|
||||
runtime_option = custom_option;
|
||||
runtime_option.model_format = model_format;
|
||||
runtime_option.model_file = model_file;
|
||||
runtime_option.params_file = params_file;
|
||||
initialized = Initialize();
|
||||
}
|
||||
|
||||
bool BlazeFace::Initialize(){
|
||||
if (!InitRuntime()){
|
||||
FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool BlazeFace::Predict(const cv::Mat& im, FaceDetectionResult* result){
|
||||
std::vector<FaceDetectionResult> results;
|
||||
if (!this->BatchPredict({im}, &results)) {
|
||||
return false;
|
||||
}
|
||||
*result = std::move(results[0]);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool BlazeFace::BatchPredict(const std::vector<cv::Mat>& images,
|
||||
std::vector<FaceDetectionResult>* results){
|
||||
std::vector<FDMat> fd_images = WrapMat(images);
|
||||
FDASSERT(images.size() == 1, "Only support batch = 1 now.");
|
||||
std::vector<std::map<std::string, std::array<float, 2>>> ims_info;
|
||||
if (!preprocessor_.Run(&fd_images, &reused_input_tensors_, &ims_info)) {
|
||||
FDERROR << "Failed to preprocess the input image." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
reused_input_tensors_[0].name = "image";
|
||||
reused_input_tensors_[1].name = "scale_factor";
|
||||
reused_input_tensors_[2].name = "im_shape";
|
||||
|
||||
// Some models don't need scale_factor and im_shape as input
|
||||
while (reused_input_tensors_.size() != NumInputsOfRuntime()) {
|
||||
reused_input_tensors_.pop_back();
|
||||
}
|
||||
|
||||
if (!Infer(reused_input_tensors_, &reused_output_tensors_)) {
|
||||
FDERROR << "Failed to inference by runtime." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!postprocessor_.Run(reused_output_tensors_, results, ims_info)){
|
||||
FDERROR << "Failed to postprocess the inference results by runtime." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace facedet
|
||||
|
||||
} // namespace vision
|
||||
|
||||
} // namespace fastdeploy
|
83
fastdeploy/vision/facedet/ppdet/blazeface/blazeface.h
Normal file
83
fastdeploy/vision/facedet/ppdet/blazeface/blazeface.h
Normal file
@@ -0,0 +1,83 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
#include "fastdeploy/fastdeploy_model.h"
|
||||
#include "fastdeploy/vision/common/processors/transform.h"
|
||||
#include "fastdeploy/vision/common/result.h"
|
||||
#include "fastdeploy/vision/facedet/ppdet/blazeface/preprocessor.h"
|
||||
#include "fastdeploy/vision/facedet/ppdet/blazeface/postprocessor.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
|
||||
namespace vision {
|
||||
|
||||
namespace facedet {
|
||||
/*! @brief BlazeFace model object used when to load a BlazeFace model exported by BlazeFace.
|
||||
*/
|
||||
class FASTDEPLOY_DECL BlazeFace: public FastDeployModel{
|
||||
public:
|
||||
/** \brief Set path of model file and the configuration of runtime.
|
||||
*
|
||||
* \param[in] model_file Path of model file, e.g ./blazeface.onnx
|
||||
* \param[in] params_file Path of parameter file, e.g ppyoloe/model.pdiparams, if the model format is ONNX, this parameter will be ignored
|
||||
* \param[in] config_file Path of configuration file for deployment, e.g resnet/infer_cfg.yml
|
||||
* \param[in] custom_option RuntimeOption for inference, the default will use cpu, and choose the backend defined in "valid_cpu_backends"
|
||||
* \param[in] model_format Model format of the loaded model, default is ONNX format
|
||||
*/
|
||||
BlazeFace(const std::string& model_file, const std::string& params_file = "",
|
||||
const std::string& config_file = "",
|
||||
const RuntimeOption& custom_option = RuntimeOption(),
|
||||
const ModelFormat& model_format = ModelFormat::PADDLE);
|
||||
|
||||
std::string ModelName() {return "blaze-face";}
|
||||
|
||||
/** \brief Predict the detection result for an input image
|
||||
*
|
||||
* \param[in] img The input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format
|
||||
* \param[in] result The output detection result will be writen to this structure
|
||||
* \return true if the prediction successed, otherwise false
|
||||
*/
|
||||
bool Predict(const cv::Mat& im, FaceDetectionResult* result);
|
||||
|
||||
/** \brief Predict the detection results for a batch of input images
|
||||
*
|
||||
* \param[in] imgs, The input image list, each element comes from cv::imread()
|
||||
* \param[in] results The output detection result list
|
||||
* \return true if the prediction successed, otherwise false
|
||||
*/
|
||||
virtual bool BatchPredict(const std::vector<cv::Mat>& images,
|
||||
std::vector<FaceDetectionResult>* results);
|
||||
|
||||
/// Get preprocessor reference of BlazeFace
|
||||
virtual BlazeFacePreprocessor& GetPreprocessor() {
|
||||
return preprocessor_;
|
||||
}
|
||||
|
||||
/// Get postprocessor reference of BlazeFace
|
||||
virtual BlazeFacePostprocessor& GetPostprocessor() {
|
||||
return postprocessor_;
|
||||
}
|
||||
|
||||
protected:
|
||||
bool Initialize();
|
||||
BlazeFacePreprocessor preprocessor_;
|
||||
BlazeFacePostprocessor postprocessor_;
|
||||
};
|
||||
|
||||
} // namespace facedet
|
||||
|
||||
} // namespace vision
|
||||
|
||||
} // namespace fastdeploy
|
@@ -0,0 +1,84 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "fastdeploy/pybind/main.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
void BindBlazeFace(pybind11::module& m) {
|
||||
pybind11::class_<vision::facedet::BlazeFacePreprocessor>(
|
||||
m, "BlazeFacePreprocessor")
|
||||
.def(pybind11::init<>())
|
||||
.def("run", [](vision::facedet::BlazeFacePreprocessor& self, std::vector<pybind11::array>& im_list) {
|
||||
std::vector<vision::FDMat> images;
|
||||
for (size_t i = 0; i < im_list.size(); ++i) {
|
||||
images.push_back(vision::WrapMat(PyArrayToCvMat(im_list[i])));
|
||||
}
|
||||
std::vector<FDTensor> outputs;
|
||||
std::vector<std::map<std::string, std::array<float, 2>>> ims_info;
|
||||
if (!self.Run(&images, &outputs, &ims_info)) {
|
||||
throw std::runtime_error("Failed to preprocess the input data in BlazeFacePreprocessor.");
|
||||
}
|
||||
for (size_t i = 0; i < outputs.size(); ++i) {
|
||||
outputs[i].StopSharing();
|
||||
}
|
||||
return make_pair(outputs, ims_info);
|
||||
});
|
||||
|
||||
pybind11::class_<vision::facedet::BlazeFacePostprocessor>(
|
||||
m, "BlazeFacePostprocessor")
|
||||
.def(pybind11::init<>())
|
||||
.def("run", [](vision::facedet::BlazeFacePostprocessor& self, std::vector<FDTensor>& inputs,
|
||||
const std::vector<std::map<std::string, std::array<float, 2>>>& ims_info) {
|
||||
std::vector<vision::FaceDetectionResult> results;
|
||||
if (!self.Run(inputs, &results, ims_info)) {
|
||||
throw std::runtime_error("Failed to postprocess the runtime result in BlazeFacePostprocessor.");
|
||||
}
|
||||
return results;
|
||||
})
|
||||
.def("run", [](vision::facedet::BlazeFacePostprocessor& self, std::vector<pybind11::array>& input_array,
|
||||
const std::vector<std::map<std::string, std::array<float, 2>>>& ims_info) {
|
||||
std::vector<vision::FaceDetectionResult> results;
|
||||
std::vector<FDTensor> inputs;
|
||||
PyArrayToTensorList(input_array, &inputs, /*share_buffer=*/true);
|
||||
if (!self.Run(inputs, &results, ims_info)) {
|
||||
throw std::runtime_error("Failed to postprocess the runtime result in BlazePostprocessor.");
|
||||
}
|
||||
return results;
|
||||
})
|
||||
.def_property("conf_threshold", &vision::facedet::BlazeFacePostprocessor::GetConfThreshold, &vision::facedet::BlazeFacePostprocessor::SetConfThreshold)
|
||||
.def_property("nms_threshold", &vision::facedet::BlazeFacePostprocessor::GetNMSThreshold, &vision::facedet::BlazeFacePostprocessor::SetNMSThreshold);
|
||||
|
||||
pybind11::class_<vision::facedet::BlazeFace, FastDeployModel>(m, "BlazeFace")
|
||||
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
|
||||
ModelFormat>())
|
||||
.def("predict",
|
||||
[](vision::facedet::BlazeFace& self, pybind11::array& data) {
|
||||
auto mat = PyArrayToCvMat(data);
|
||||
vision::FaceDetectionResult res;
|
||||
self.Predict(mat, &res);
|
||||
return res;
|
||||
})
|
||||
.def("batch_predict", [](vision::facedet::BlazeFace& self, std::vector<pybind11::array>& data) {
|
||||
std::vector<cv::Mat> images;
|
||||
for (size_t i = 0; i < data.size(); ++i) {
|
||||
images.push_back(PyArrayToCvMat(data[i]));
|
||||
}
|
||||
std::vector<vision::FaceDetectionResult> results;
|
||||
self.BatchPredict(images, &results);
|
||||
return results;
|
||||
})
|
||||
.def_property_readonly("preprocessor", &vision::facedet::BlazeFace::GetPreprocessor)
|
||||
.def_property_readonly("postprocessor", &vision::facedet::BlazeFace::GetPostprocessor);
|
||||
}
|
||||
} // namespace fastdeploy
|
96
fastdeploy/vision/facedet/ppdet/blazeface/postprocessor.cc
Normal file
96
fastdeploy/vision/facedet/ppdet/blazeface/postprocessor.cc
Normal file
@@ -0,0 +1,96 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "fastdeploy/vision/facedet/ppdet/blazeface/postprocessor.h"
|
||||
#include "fastdeploy/vision/utils/utils.h"
|
||||
#include "fastdeploy/vision/detection/ppdet/multiclass_nms.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
|
||||
namespace vision {
|
||||
|
||||
namespace facedet {
|
||||
|
||||
BlazeFacePostprocessor::BlazeFacePostprocessor() {
|
||||
conf_threshold_ = 0.5;
|
||||
nms_threshold_ = 0.3;
|
||||
}
|
||||
|
||||
bool BlazeFacePostprocessor::Run(const std::vector<FDTensor>& tensors,
|
||||
std::vector<FaceDetectionResult>* results,
|
||||
const std::vector<std::map<std::string,
|
||||
std::array<float, 2>>>& ims_info) {
|
||||
// Get number of boxes for each input image
|
||||
std::vector<int> num_boxes(tensors[1].shape[0]);
|
||||
int total_num_boxes = 0;
|
||||
if (tensors[1].dtype == FDDataType::INT32) {
|
||||
const auto* data = static_cast<const int32_t*>(tensors[1].CpuData());
|
||||
for (size_t i = 0; i < tensors[1].shape[0]; ++i) {
|
||||
num_boxes[i] = static_cast<int>(data[i]);
|
||||
total_num_boxes += num_boxes[i];
|
||||
}
|
||||
} else if (tensors[1].dtype == FDDataType::INT64) {
|
||||
const auto* data = static_cast<const int64_t*>(tensors[1].CpuData());
|
||||
for (size_t i = 0; i < tensors[1].shape[0]; ++i) {
|
||||
num_boxes[i] = static_cast<int>(data[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// Special case for TensorRT, it has fixed output shape of NMS
|
||||
// So there's invalid boxes in its' output boxes
|
||||
int num_output_boxes = static_cast<int>(tensors[0].Shape()[0]);
|
||||
bool contain_invalid_boxes = false;
|
||||
if (total_num_boxes != num_output_boxes) {
|
||||
if (num_output_boxes % num_boxes.size() == 0) {
|
||||
contain_invalid_boxes = true;
|
||||
} else {
|
||||
FDERROR << "Cannot handle the output data for this model, unexpected "
|
||||
"situation."
|
||||
<< std::endl;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Get boxes for each input image
|
||||
results->resize(num_boxes.size());
|
||||
|
||||
if (tensors[0].shape[0] == 0) {
|
||||
// No detected boxes
|
||||
return true;
|
||||
}
|
||||
|
||||
const auto* box_data = static_cast<const float*>(tensors[0].CpuData());
|
||||
int offset = 0;
|
||||
for (size_t i = 0; i < num_boxes.size(); ++i) {
|
||||
const float* ptr = box_data + offset;
|
||||
(*results)[i].Reserve(num_boxes[i]);
|
||||
for (size_t j = 0; j < num_boxes[i]; ++j) {
|
||||
if (ptr[j * 6 + 1] > conf_threshold_) {
|
||||
(*results)[i].scores.push_back(ptr[j * 6 + 1]);
|
||||
(*results)[i].boxes.emplace_back(std::array<float, 4>(
|
||||
{ptr[j * 6 + 2], ptr[j * 6 + 3], ptr[j * 6 + 4], ptr[j * 6 + 5]}));
|
||||
}
|
||||
}
|
||||
if (contain_invalid_boxes) {
|
||||
offset += static_cast<int>(num_output_boxes * 6 / num_boxes.size());
|
||||
} else {
|
||||
offset += static_cast<int>(num_boxes[i] * 6);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace detection
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
66
fastdeploy/vision/facedet/ppdet/blazeface/postprocessor.h
Normal file
66
fastdeploy/vision/facedet/ppdet/blazeface/postprocessor.h
Normal file
@@ -0,0 +1,66 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
#include "fastdeploy/vision/common/processors/transform.h"
|
||||
#include "fastdeploy/vision/common/result.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
|
||||
namespace vision {
|
||||
|
||||
namespace facedet {
|
||||
|
||||
class FASTDEPLOY_DECL BlazeFacePostprocessor{
|
||||
public:
|
||||
/*! @brief Postprocessor object for BlazeFace serials model.
|
||||
*/
|
||||
BlazeFacePostprocessor();
|
||||
|
||||
/** \brief Process the result of runtime and fill to FaceDetectionResult structure
|
||||
*
|
||||
* \param[in] infer_result The inference result from runtime
|
||||
* \param[in] results The output result of detection
|
||||
* \param[in] ims_info The shape info list, record input_shape and output_shape
|
||||
* \return true if the postprocess successed, otherwise false
|
||||
*/
|
||||
bool Run(const std::vector<FDTensor>& infer_result,
|
||||
std::vector<FaceDetectionResult>* results,
|
||||
const std::vector<std::map<std::string,
|
||||
std::array<float, 2>>>& ims_info);
|
||||
|
||||
/// Set conf_threshold, default 0.5
|
||||
void SetConfThreshold(const float& conf_threshold) {
|
||||
conf_threshold_ = conf_threshold;
|
||||
}
|
||||
|
||||
/// Get conf_threshold, default 0.5
|
||||
float GetConfThreshold() const { return conf_threshold_; }
|
||||
|
||||
/// Set nms_threshold, default 0.3
|
||||
void SetNMSThreshold(const float& nms_threshold) {
|
||||
nms_threshold_ = nms_threshold;
|
||||
}
|
||||
|
||||
/// Get nms_threshold, default 0.3
|
||||
float GetNMSThreshold() const { return nms_threshold_; }
|
||||
|
||||
protected:
|
||||
float conf_threshold_;
|
||||
float nms_threshold_;
|
||||
};
|
||||
|
||||
} // namespace facedet
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
207
fastdeploy/vision/facedet/ppdet/blazeface/preprocessor.cc
Normal file
207
fastdeploy/vision/facedet/ppdet/blazeface/preprocessor.cc
Normal file
@@ -0,0 +1,207 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "fastdeploy/vision/facedet/ppdet/blazeface/preprocessor.h"
|
||||
#include "fastdeploy/function/concat.h"
|
||||
#include "fastdeploy/function/pad.h"
|
||||
#include "fastdeploy/vision/common/processors/mat.h"
|
||||
#include "yaml-cpp/yaml.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
|
||||
namespace vision {
|
||||
|
||||
namespace facedet {
|
||||
|
||||
BlazeFacePreprocessor::BlazeFacePreprocessor(const std::string& config_file) {
|
||||
is_scale_ = false;
|
||||
normalize_mean_ = {123, 117, 104};
|
||||
normalize_std_ = {127.502231, 127.502231, 127.502231};
|
||||
this->config_file_ = config_file;
|
||||
FDASSERT(BuildPreprocessPipelineFromConfig(),
|
||||
"Failed to create PaddleDetPreprocessor.");
|
||||
}
|
||||
|
||||
bool BlazeFacePreprocessor::Run(std::vector<FDMat>* images, std::vector<FDTensor>* outputs,
|
||||
std::vector<std::map<std::string, std::array<float, 2>>>* ims_info) {
|
||||
if (images->size() == 0) {
|
||||
FDERROR << "The size of input images should be greater than 0." << std::endl;
|
||||
return false;
|
||||
}
|
||||
ims_info->resize(images->size());
|
||||
outputs->resize(3);
|
||||
int batch = static_cast<int>(images->size());
|
||||
// Allocate memory for scale_factor
|
||||
(*outputs)[1].Resize({batch, 2}, FDDataType::FP32);
|
||||
// Allocate memory for im_shape
|
||||
(*outputs)[2].Resize({batch, 2}, FDDataType::FP32);
|
||||
|
||||
std::vector<int> max_hw({-1, -1});
|
||||
|
||||
auto* scale_factor_ptr =
|
||||
reinterpret_cast<float*>((*outputs)[1].MutableData());
|
||||
auto* im_shape_ptr = reinterpret_cast<float*>((*outputs)[2].MutableData());
|
||||
|
||||
// Concat all the preprocessed data to a batch tensor
|
||||
std::vector<FDTensor> im_tensors(images->size());
|
||||
|
||||
for (size_t i = 0; i < images->size(); ++i) {
|
||||
int origin_w = (*images)[i].Width();
|
||||
int origin_h = (*images)[i].Height();
|
||||
scale_factor_ptr[2 * i] = 1.0;
|
||||
scale_factor_ptr[2 * i + 1] = 1.0;
|
||||
|
||||
for (size_t j = 0; j < processors_.size(); ++j) {
|
||||
if (!(*(processors_[j].get()))(&((*images)[i]))) {
|
||||
FDERROR << "Failed to processs image:" << i << " in "
|
||||
<< processors_[i]->Name() << "." << std::endl;
|
||||
return false;
|
||||
}
|
||||
if (processors_[j]->Name().find("Resize") != std::string::npos) {
|
||||
scale_factor_ptr[2 * i] = (*images)[i].Height() * 1.0 / origin_h;
|
||||
scale_factor_ptr[2 * i + 1] = (*images)[i].Width() * 1.0 / origin_w;
|
||||
}
|
||||
}
|
||||
|
||||
if ((*images)[i].Height() > max_hw[0]) {
|
||||
max_hw[0] = (*images)[i].Height();
|
||||
}
|
||||
if ((*images)[i].Width() > max_hw[1]) {
|
||||
max_hw[1] = (*images)[i].Width();
|
||||
}
|
||||
im_shape_ptr[2 * i] = max_hw[0];
|
||||
im_shape_ptr[2 * i + 1] = max_hw[1];
|
||||
|
||||
if ((*images)[i].Height() < max_hw[0] || (*images)[i].Width() < max_hw[1]) {
|
||||
// if the size of image less than max_hw, pad to max_hw
|
||||
FDTensor tensor;
|
||||
(*images)[i].ShareWithTensor(&tensor);
|
||||
function::Pad(tensor, &(im_tensors[i]),
|
||||
{0, 0, max_hw[0] - (*images)[i].Height(),
|
||||
max_hw[1] - (*images)[i].Width()},
|
||||
0);
|
||||
} else {
|
||||
// No need pad
|
||||
(*images)[i].ShareWithTensor(&(im_tensors[i]));
|
||||
}
|
||||
// Reshape to 1xCxHxW
|
||||
im_tensors[i].ExpandDim(0);
|
||||
}
|
||||
|
||||
if (im_tensors.size() == 1) {
|
||||
// If there's only 1 input, no need to concat
|
||||
// skip memory copy
|
||||
(*outputs)[0] = std::move(im_tensors[0]);
|
||||
} else {
|
||||
// Else concat the im tensor for each input image
|
||||
// compose a batched input tensor
|
||||
function::Concat(im_tensors, &((*outputs)[0]), 0);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool BlazeFacePreprocessor::BuildPreprocessPipelineFromConfig() {
|
||||
processors_.clear();
|
||||
YAML::Node cfg;
|
||||
try {
|
||||
cfg = YAML::LoadFile(config_file_);
|
||||
} catch (YAML::BadFile& e) {
|
||||
FDERROR << "Failed to load yaml file " << config_file_
|
||||
<< ", maybe you should check this file." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
processors_.push_back(std::make_shared<BGR2RGB>());
|
||||
|
||||
bool has_permute = false;
|
||||
for (const auto& op : cfg["Preprocess"]) {
|
||||
std::string op_name = op["type"].as<std::string>();
|
||||
if (op_name == "NormalizeImage") {
|
||||
auto mean = op["mean"].as<std::vector<float>>();
|
||||
auto std = op["std"].as<std::vector<float>>();
|
||||
bool is_scale = true;
|
||||
if (op["is_scale"]) {
|
||||
is_scale = op["is_scale"].as<bool>();
|
||||
}
|
||||
std::string norm_type = "mean_std";
|
||||
if (op["norm_type"]) {
|
||||
norm_type = op["norm_type"].as<std::string>();
|
||||
}
|
||||
if (norm_type != "mean_std") {
|
||||
std::fill(mean.begin(), mean.end(), 0.0);
|
||||
std::fill(std.begin(), std.end(), 1.0);
|
||||
}
|
||||
processors_.push_back(std::make_shared<Normalize>(mean, std, is_scale));
|
||||
} else if (op_name == "Resize") {
|
||||
bool keep_ratio = op["keep_ratio"].as<bool>();
|
||||
auto target_size = op["target_size"].as<std::vector<int>>();
|
||||
int interp = op["interp"].as<int>();
|
||||
FDASSERT(target_size.size() == 2,
|
||||
"Require size of target_size be 2, but now it's %lu.",
|
||||
target_size.size());
|
||||
if (!keep_ratio) {
|
||||
int width = target_size[1];
|
||||
int height = target_size[0];
|
||||
processors_.push_back(
|
||||
std::make_shared<Resize>(width, height, -1.0, -1.0, interp, false));
|
||||
} else {
|
||||
int min_target_size = std::min(target_size[0], target_size[1]);
|
||||
int max_target_size = std::max(target_size[0], target_size[1]);
|
||||
std::vector<int> max_size;
|
||||
if (max_target_size > 0) {
|
||||
max_size.push_back(max_target_size);
|
||||
max_size.push_back(max_target_size);
|
||||
}
|
||||
processors_.push_back(std::make_shared<ResizeByShort>(
|
||||
min_target_size, interp, true, max_size));
|
||||
}
|
||||
} else if (op_name == "Permute") {
|
||||
// Do nothing, do permute as the last operation
|
||||
has_permute = true;
|
||||
continue;
|
||||
} else if (op_name == "Pad") {
|
||||
auto size = op["size"].as<std::vector<int>>();
|
||||
auto value = op["fill_value"].as<std::vector<float>>();
|
||||
processors_.push_back(std::make_shared<Cast>("float"));
|
||||
processors_.push_back(
|
||||
std::make_shared<PadToSize>(size[1], size[0], value));
|
||||
} else if (op_name == "PadStride") {
|
||||
auto stride = op["stride"].as<int>();
|
||||
processors_.push_back(
|
||||
std::make_shared<StridePad>(stride, std::vector<float>(3, 0)));
|
||||
} else {
|
||||
FDERROR << "Unexcepted preprocess operator: " << op_name << "."
|
||||
<< std::endl;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (has_permute) {
|
||||
// permute = cast<float> + HWC2CHW
|
||||
processors_.push_back(std::make_shared<Cast>("float"));
|
||||
processors_.push_back(std::make_shared<HWC2CHW>());
|
||||
}
|
||||
|
||||
// Fusion will improve performance
|
||||
FuseTransforms(&processors_);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace facedet
|
||||
|
||||
} // namespace vision
|
||||
|
||||
} // namespacefastdeploy
|
69
fastdeploy/vision/facedet/ppdet/blazeface/preprocessor.h
Normal file
69
fastdeploy/vision/facedet/ppdet/blazeface/preprocessor.h
Normal file
@@ -0,0 +1,69 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
#include "fastdeploy/vision/common/processors/transform.h"
|
||||
#include "fastdeploy/vision/common/result.h"
|
||||
#include "fastdeploy/vision/detection/ppdet/preprocessor.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
|
||||
namespace vision {
|
||||
|
||||
namespace facedet {
|
||||
|
||||
class FASTDEPLOY_DECL BlazeFacePreprocessor:
|
||||
public fastdeploy::vision::detection::PaddleDetPreprocessor {
|
||||
public:
|
||||
/** \brief Create a preprocessor instance for BlazeFace serials model
|
||||
*/
|
||||
BlazeFacePreprocessor() = default;
|
||||
|
||||
/** \brief Create a preprocessor instance for Blazeface serials model
|
||||
*
|
||||
* \param[in] config_file Path of configuration file for deployment, e.g ppyoloe/infer_cfg.yml
|
||||
*/
|
||||
explicit BlazeFacePreprocessor(const std::string& config_file);
|
||||
|
||||
/** \brief Process the input image and prepare input tensors for runtime
|
||||
*
|
||||
* \param[in] images The input image data list, all the elements are returned by cv::imread()
|
||||
* \param[in] outputs The output tensors which will feed in runtime
|
||||
* \param[in] ims_info The shape info list, record input_shape and output_shape
|
||||
* \ret
|
||||
*/
|
||||
bool Run(std::vector<FDMat>* images, std::vector<FDTensor>* outputs,
|
||||
std::vector<std::map<std::string, std::array<float, 2>>>* ims_info);
|
||||
|
||||
private:
|
||||
bool BuildPreprocessPipelineFromConfig();
|
||||
|
||||
// if is_scale_up is false, the input image only can be zoom out,
|
||||
// the maximum resize scale cannot exceed 1.0
|
||||
bool is_scale_;
|
||||
|
||||
std::vector<float> normalize_mean_;
|
||||
|
||||
std::vector<float> normalize_std_;
|
||||
|
||||
std::vector<std::shared_ptr<Processor>> processors_;
|
||||
// read config file
|
||||
std::string config_file_;
|
||||
};
|
||||
|
||||
} // namespace facedet
|
||||
|
||||
} // namespace vision
|
||||
|
||||
} // namespace fastdeploy
|
@@ -15,6 +15,7 @@
|
||||
from __future__ import absolute_import
|
||||
from .contrib.yolov5face import YOLOv5Face
|
||||
from .contrib.yolov7face import *
|
||||
from .contrib.blazeface import *
|
||||
from .contrib.retinaface import RetinaFace
|
||||
from .contrib.scrfd import SCRFD
|
||||
from .contrib.ultraface import UltraFace
|
||||
|
143
python/fastdeploy/vision/facedet/contrib/blazeface.py
Normal file
143
python/fastdeploy/vision/facedet/contrib/blazeface.py
Normal file
@@ -0,0 +1,143 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
import logging
|
||||
from .... import FastDeployModel, ModelFormat
|
||||
from .... import c_lib_wrap as C
|
||||
|
||||
|
||||
class BlazeFacePreprocessor:
|
||||
def __init__(self):
|
||||
"""Create a preprocessor for BlazeFace
|
||||
"""
|
||||
self._preprocessor = C.vision.facedet.BlazeFacePreprocessor()
|
||||
|
||||
def run(self, input_ims):
|
||||
"""Preprocess input images for BlazeFace
|
||||
|
||||
:param: input_ims: (list of numpy.ndarray)The input image
|
||||
:return: list of FDTensor
|
||||
"""
|
||||
return self._preprocessor.run(input_ims)
|
||||
|
||||
@property
|
||||
def is_scale_(self):
|
||||
"""
|
||||
is_scale_ for preprocessing, the input image only can be zoom out, the maximum resize scale cannot exceed 1.0, default true
|
||||
"""
|
||||
return self._preprocessor.is_scale_
|
||||
|
||||
@is_scale_.setter
|
||||
def is_scale_(self, value):
|
||||
assert isinstance(
|
||||
value,
|
||||
bool), "The value to set `is_scale_` must be type of bool."
|
||||
self._preprocessor.is_scale_ = value
|
||||
|
||||
|
||||
class BlazeFacePostprocessor:
|
||||
def __init__(self):
|
||||
"""Create a postprocessor for BlazeFace
|
||||
"""
|
||||
self._postprocessor = C.vision.facedet.BlazeFacePostprocessor()
|
||||
|
||||
def run(self, runtime_results, ims_info):
|
||||
"""Postprocess the runtime results for BlazeFace
|
||||
|
||||
:param: runtime_results: (list of FDTensor)The output FDTensor results from runtime
|
||||
:param: ims_info: (list of dict)Record input_shape and output_shape
|
||||
:return: list of DetectionResult(If the runtime_results is predict by batched samples, the length of this list equals to the batch size)
|
||||
"""
|
||||
return self._postprocessor.run(runtime_results, ims_info)
|
||||
|
||||
@property
|
||||
def conf_threshold(self):
|
||||
"""
|
||||
confidence threshold for postprocessing, default is 0.5
|
||||
"""
|
||||
return self._postprocessor.conf_threshold
|
||||
|
||||
@property
|
||||
def nms_threshold(self):
|
||||
"""
|
||||
nms threshold for postprocessing, default is 0.3
|
||||
"""
|
||||
return self._postprocessor.nms_threshold
|
||||
|
||||
@conf_threshold.setter
|
||||
def conf_threshold(self, conf_threshold):
|
||||
assert isinstance(conf_threshold, float),\
|
||||
"The value to set `conf_threshold` must be type of float."
|
||||
self._postprocessor.conf_threshold = conf_threshold
|
||||
|
||||
@nms_threshold.setter
|
||||
def nms_threshold(self, nms_threshold):
|
||||
assert isinstance(nms_threshold, float),\
|
||||
"The value to set `nms_threshold` must be type of float."
|
||||
self._postprocessor.nms_threshold = nms_threshold
|
||||
|
||||
|
||||
class BlazeFace(FastDeployModel):
|
||||
def __init__(self,
|
||||
model_file,
|
||||
params_file="",
|
||||
config_file="",
|
||||
runtime_option=None,
|
||||
model_format=ModelFormat.PADDLE):
|
||||
"""Load a BlazeFace model exported by BlazeFace.
|
||||
|
||||
:param model_file: (str)Path of model file, e.g ./Blazeface.onnx
|
||||
:param params_file: (str)Path of parameters file, e.g yolox/model.pdiparams, if the model_fomat is ModelFormat.ONNX, this param will be ignored, can be set as empty string
|
||||
:param runtime_option: (fastdeploy.RuntimeOption)RuntimeOption for inference this model, if it's None, will use the default backend on CPU
|
||||
:param model_format: (fastdeploy.ModelForamt)Model format of the loaded model
|
||||
"""
|
||||
super(BlazeFace, self).__init__(runtime_option)
|
||||
|
||||
self._model = C.vision.facedet.BlazeFace(
|
||||
model_file, params_file, config_file, self._runtime_option, model_format)
|
||||
|
||||
assert self.initialized, "BlazeFace initialize failed."
|
||||
|
||||
def predict(self, input_image):
|
||||
"""Detect the location and key points of human faces from an input image
|
||||
:param input_image: (numpy.ndarray)The input image data, 3-D array with layout HWC, BGR format
|
||||
:return: FaceDetectionResult
|
||||
"""
|
||||
return self._model.predict(input_image)
|
||||
|
||||
def batch_predict(self, images):
|
||||
"""Classify a batch of input image
|
||||
|
||||
:param im: (list of numpy.ndarray) The input image list, each element is a 3-D array with layout HWC, BGR format
|
||||
:return list of FaceDetectionResult
|
||||
"""
|
||||
|
||||
return self._model.batch_predict(images)
|
||||
|
||||
@property
|
||||
def preprocessor(self):
|
||||
"""Get BlazefacePreprocessor object of the loaded model
|
||||
|
||||
:return BlazefacePreprocessor
|
||||
"""
|
||||
return self._model.preprocessor
|
||||
|
||||
@property
|
||||
def postprocessor(self):
|
||||
"""Get BlazefacePostprocessor object of the loaded model
|
||||
|
||||
:return BlazefacePostprocessor
|
||||
"""
|
||||
return self._model.postprocessor
|
151
tests/models/test_blazeface.py
Normal file
151
tests/models/test_blazeface.py
Normal file
@@ -0,0 +1,151 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from fastdeploy import ModelFormat
|
||||
import fastdeploy as fd
|
||||
import cv2
|
||||
import os
|
||||
import pickle
|
||||
import numpy as np
|
||||
import runtime_config as rc
|
||||
|
||||
|
||||
def test_detection_blazeface():
|
||||
model_url = "https://bj.bcebos.com/paddlehub/fastdeploy/blazeface_1000e.tgz"
|
||||
input_url1 = "https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg"
|
||||
input_url2 = "https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000570688.jpg"
|
||||
result_url1 = "https://bj.bcebos.com/paddlehub/fastdeploy/blazeface_result1.pkl"
|
||||
result_url2 = "https://bj.bcebos.com/paddlehub/fastdeploy/blazeface_result2.pkl"
|
||||
fd.download_and_decompress(model_url, "resources")
|
||||
fd.download(input_url1, "resources")
|
||||
fd.download(input_url2, "resources")
|
||||
|
||||
|
||||
model_dir = "resources/blazeface_1000e"
|
||||
model_file = os.path.join(model_dir, "model.pdmodel")
|
||||
params_file = os.path.join(model_dir, "model.pdiparams")
|
||||
config_file = os.path.join(model_dir, "infer_cfg.yml")
|
||||
model = fd.vision.facedet.BlazeFace(
|
||||
model_file, params_file, config_file, runtime_option=rc.test_option)
|
||||
model.postprocessor.conf_threshold = 0.5
|
||||
|
||||
with open("resources/blazeface_result1.pkl", "rb") as f:
|
||||
expect1 = pickle.load(f)
|
||||
|
||||
with open("resources/blazeface_result2.pkl", "rb") as f:
|
||||
expect2 = pickle.load(f)
|
||||
|
||||
im1 = cv2.imread("./resources/000000014439.jpg")
|
||||
im2 = cv2.imread("./resources/000000570688.jpg")
|
||||
|
||||
for i in range(3):
|
||||
# test single predict
|
||||
result1 = model.predict(im1)
|
||||
result2 = model.predict(im2)
|
||||
|
||||
diff_boxes_1 = np.fabs(
|
||||
np.array(result1.boxes) - np.array(expect1["boxes"]))
|
||||
diff_boxes_2 = np.fabs(
|
||||
np.array(result2.boxes) - np.array(expect2["boxes"]))
|
||||
|
||||
diff_scores_1 = np.fabs(
|
||||
np.array(result1.scores) - np.array(expect1["scores"]))
|
||||
diff_scores_2 = np.fabs(
|
||||
np.array(result2.scores) - np.array(expect2["scores"]))
|
||||
|
||||
assert diff_boxes_1.max(
|
||||
) < 1e-04, "There's difference in detection boxes 1."
|
||||
assert diff_scores_1.max(
|
||||
) < 1e-04, "There's difference in detection score 1."
|
||||
|
||||
assert diff_boxes_2.max(
|
||||
) < 1e-03, "There's difference in detection boxes 2."
|
||||
assert diff_scores_2.max(
|
||||
) < 1e-04, "There's difference in detection score 2."
|
||||
|
||||
print("one image test success!")
|
||||
|
||||
# test batch predict
|
||||
results = model.batch_predict([im1, im2])
|
||||
result1 = results[0]
|
||||
result2 = results[1]
|
||||
|
||||
diff_boxes_1 = np.fabs(
|
||||
np.array(result1.boxes) - np.array(expect1["boxes"]))
|
||||
diff_boxes_2 = np.fabs(
|
||||
np.array(result2.boxes) - np.array(expect2["boxes"]))
|
||||
|
||||
diff_scores_1 = np.fabs(
|
||||
np.array(result1.scores) - np.array(expect1["scores"]))
|
||||
diff_scores_2 = np.fabs(
|
||||
np.array(result2.scores) - np.array(expect2["scores"]))
|
||||
assert diff_boxes_1.max(
|
||||
) < 1e-04, "There's difference in detection boxes 1."
|
||||
assert diff_scores_1.max(
|
||||
) < 1e-03, "There's difference in detection score 1."
|
||||
|
||||
assert diff_boxes_2.max(
|
||||
) < 1e-04, "There's difference in detection boxes 2."
|
||||
assert diff_scores_2.max(
|
||||
) < 1e-04, "There's difference in detection score 2."
|
||||
|
||||
print("batch predict success!")
|
||||
|
||||
|
||||
def test_detection_blazeface_runtime():
|
||||
model_url = "https://bj.bcebos.com/paddlehub/fastdeploy/blazeface_1000e.tgz"
|
||||
input_url1 = "https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg"
|
||||
result_url1 = "https://bj.bcebos.com/paddlehub/fastdeploy/blazeface_result1.pkl"
|
||||
fd.download_and_decompress(model_url, "resources")
|
||||
fd.download(input_url1, "resources")
|
||||
fd.download(result_url1, "resources")
|
||||
|
||||
model_dir = "resources/blazeface_1000e"
|
||||
model_file = os.path.join(model_dir, "model.pdmodel")
|
||||
params_file = os.path.join(model_dir, "model.pdiparams")
|
||||
config_file = os.path.join(model_dir, "infer_cfg.yml")
|
||||
|
||||
preprocessor = fd.vision.facedet.BlazeFacePreprocessor()
|
||||
postprocessor = fd.vision.facedet.BlazeFacePostprocessor()
|
||||
|
||||
rc.test_option.set_model_path(model_file, params_file, config_file, model_format=ModelFormat.PADDLE)
|
||||
rc.test_option.use_openvino_backend()
|
||||
runtime = fd.Runtime(rc.test_option)
|
||||
|
||||
with open("resources/blazeface_result1.pkl", "rb") as f:
|
||||
expect1 = pickle.load(f)
|
||||
|
||||
im1 = cv2.imread("resources/000000014439.jpg")
|
||||
|
||||
for i in range(3):
|
||||
# test runtime
|
||||
input_tensors, ims_info = preprocessor.run([im1.copy()])
|
||||
output_tensors = runtime.infer({"images": input_tensors[0]})
|
||||
results = postprocessor.run(output_tensors, ims_info)
|
||||
result1 = results[0]
|
||||
|
||||
diff_boxes_1 = np.fabs(
|
||||
np.array(result1.boxes) - np.array(expect1["boxes"]))
|
||||
diff_scores_1 = np.fabs(
|
||||
np.array(result1.scores) - np.array(expect1["scores"]))
|
||||
|
||||
assert diff_boxes_1.max(
|
||||
) < 1e-03, "There's difference in detection boxes 1."
|
||||
assert diff_scores_1.max(
|
||||
) < 1e-04, "There's difference in detection score 1."
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_detection_blazeface()
|
||||
test_detection_blaze_runtime()
|
Reference in New Issue
Block a user