[vision] Add AdaFace model support (#301)

* Add the AdaFace model

* Add Python code for the AdaFace model

* Add example code for the AdaFace model

* Remove unused imports

* update

* Fix errors in the faceid docs

* Fix errors in the faceid docs

* Remove unused files

* Add Paddle Inference code for the AdaFace model; the model files are committed for now to ease testing and will be removed later

* Add Paddle Inference code for the AdaFace model; the model files are committed for now to ease testing and will be removed later

* Revise as requested and get the C++ example running

* Test the Python example

* Python CPU tests pass; docs updated

* Fix docs and replace the model download URLs

* Fix docs

* Fix docs

Co-authored-by: DefTruth <31974251+DefTruth@users.noreply.github.com>
Author: Zheng_Bicheng
Date: 2022-10-11 09:55:18 +08:00
Committed by: GitHub
Parent: a6847c5432
Commit: 9c3ac8f0da
18 changed files with 804 additions and 31 deletions

View File

@@ -4,8 +4,9 @@
 FastDeploy currently supports deployment of the following face recognition models:
 | Model | Description | Model Format | Version |
-| :--- | :--- | :------- | :--- |
+|:---------------------------------------|:---------------|:-----------|:------------------------------------------------------------------------------|
 | [deepinsight/ArcFace](./insightface) | ArcFace series models | ONNX | [CommitID:babb9a5](https://github.com/deepinsight/insightface/commit/babb9a5) |
 | [deepinsight/CosFace](./insightface) | CosFace series models | ONNX | [CommitID:babb9a5](https://github.com/deepinsight/insightface/commit/babb9a5) |
 | [deepinsight/PartialFC](./insightface) | PartialFC series models | ONNX | [CommitID:babb9a5](https://github.com/deepinsight/insightface/commit/babb9a5) |
 | [deepinsight/VPL](./insightface) | VPL series models | ONNX | [CommitID:babb9a5](https://github.com/deepinsight/insightface/commit/babb9a5) |
+| [paddleclas/AdaFace](./adaface) | AdaFace series models | PADDLE | [v2.4.0](https://github.com/PaddlePaddle/PaddleClas/tree/v2.4.0) |

View File

@@ -0,0 +1,32 @@
# Preparing AdaFace Models for Deployment
- [PaddleClas](https://github.com/PaddlePaddle/PaddleClas/)
- Paddle models trained in the [official repository](https://github.com/PaddlePaddle/PaddleClas/) can be deployed after being exported as Paddle static graph models.
## Introduction
Face recognition on low-quality images has long been challenging, because the facial attributes in such images are blurred and degraded, and feeding them into a model yields poor classification.
Moreover, face recognition pipelines often rectify faces with OpenCV affine transforms, which itself degrades image quality, so handling low-quality images has become a pain point when deploying models.
In AdaFace, the authors introduce another factor into the loss function: image quality. They argue that the strategy of emphasizing misclassified samples should be adjusted according to image quality; concretely, the relative importance of easy and hard samples should be assigned based on each sample's image quality. On this basis they propose a new loss that uses image quality to weight the importance of different hard samples (summarized below).
AdaFace thus mitigates the accuracy drop on low-quality inputs and is better suited for production face recognition.
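For reference, the quality-adaptive margin from the AdaFace paper (CVPR 2022) can be sketched as follows; this is background on the training objective, not something the deployment code computes. The feature norm serves as a proxy for image quality:

$$\widehat{\lVert z_i \rVert} = \operatorname{clip}\!\left(\frac{\lVert z_i \rVert - \mu_z}{\sigma_z / h},\ -1,\ 1\right)$$

$$f(\theta_{y_i}) = s\left(\cos(\theta_{y_i} + g_{\mathrm{angle}}) - g_{\mathrm{add}}\right),\qquad g_{\mathrm{angle}} = -m\,\widehat{\lVert z_i \rVert},\qquad g_{\mathrm{add}} = m\,\widehat{\lVert z_i \rVert} + m$$

Here $\mu_z$ and $\sigma_z$ are batch statistics of the feature norm and $h$ is a concentration hyperparameter, so high-norm (high-quality) samples receive a harder margin while low-norm (low-quality) samples are de-emphasized.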
## Exporting a Paddle Static Graph Model
Taking AdaFace as an example:
for the training and export code, refer to [AIStudio](https://aistudio.baidu.com/aistudio/projectdetail/4479879?contributionType=1); a minimal export sketch follows.
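The sketch below illustrates only the export step; `build_adaface_backbone` is a placeholder, and the authoritative steps live in the AIStudio project above.

```python
import paddle

# Hypothetical export sketch: turn a trained AdaFace backbone (dynamic graph)
# into a static graph for deployment. The input name "data" matches the
# TensorRT input name used in the examples below.
model = build_adaface_backbone()  # placeholder for the trained model
model.eval()
spec = paddle.static.InputSpec(shape=[None, 3, 112, 112], dtype="float32", name="data")
# Produces mobilefacenet_adaface.pdmodel / mobilefacenet_adaface.pdiparams
paddle.jit.save(model, "mobilefacenet_adaface", input_spec=[spec])
```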
## Downloading Pre-trained Paddle Static Graph Models
For developers' convenience, converted models of each series are provided below for direct download. The accuracy figures in the table come from the official repository, specifically the per-model descriptions on AIStudio.
| Model | Size | Accuracy (AgeDB_30) |
|:----------------------------------------------------------------------------------------------|:------|:--------------|
| [AdaFace-MobileFacenet](https://bj.bcebos.com/paddlehub/fastdeploy/mobilefacenet_adaface.tgz) | 3.2MB | 95.5 |
## Detailed Deployment Documentation
- [Python deployment](python)
- [C++ deployment](cpp)

View File

@@ -0,0 +1,13 @@
PROJECT(infer_demo C CXX)
CMAKE_MINIMUM_REQUIRED(VERSION 3.12)
# Path to the downloaded and extracted FastDeploy SDK
option(FASTDEPLOY_INSTALL_DIR "Path of downloaded fastdeploy sdk.")
include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake)
# Add FastDeploy header dependencies
include_directories(${FASTDEPLOY_INCS})
add_executable(infer_demo ${PROJECT_SOURCE_DIR}/infer.cc)
target_link_libraries(infer_demo ${FASTDEPLOY_LIBS})

View File

@@ -0,0 +1,111 @@
# AdaFace C++ Deployment Example
This directory provides `infer.cc`, a quick example of deploying AdaFace on CPU/GPU, as well as on GPU with TensorRT acceleration.
Before deployment, confirm the following two steps:
- 1. The software and hardware environment meets the requirements, see [FastDeploy Environment Requirements](../../../../../docs/environment.md)
- 2. Download the prebuilt deployment library and sample code for your environment, see [FastDeploy Prebuilt Libraries](../../../../../docs/quick_start)
Taking CPU inference on Linux as an example, run the following commands in this directory to build and test:
```bash
# If the prebuilt library does not include this model, build the SDK from the latest source code
mkdir build
cd build
wget https://bj.bcebos.com/fastdeploy/release/cpp/fastdeploy-linux-x64-0.2.1.tgz
tar xvf fastdeploy-linux-x64-0.2.1.tgz
cmake .. -DFASTDEPLOY_INSTALL_DIR=${PWD}/fastdeploy-linux-x64-0.2.1
make -j
# Download test images
wget https://bj.bcebos.com/paddlehub/test_samples/test_lite_focal_arcface_0.JPG
wget https://bj.bcebos.com/paddlehub/test_samples/test_lite_focal_arcface_1.JPG
wget https://bj.bcebos.com/paddlehub/test_samples/test_lite_focal_arcface_2.JPG
# If using a Paddle model, run the following
wget https://bj.bcebos.com/paddlehub/fastdeploy/mobilefacenet_adaface.tgz
tar zxvf mobilefacenet_adaface.tgz -C ./
# CPU inference
./infer_demo mobilefacenet_adaface/mobilefacenet_adaface.pdmodel \
mobilefacenet_adaface/mobilefacenet_adaface.pdiparams \
test_lite_focal_arcface_0.JPG \
test_lite_focal_arcface_1.JPG \
test_lite_focal_arcface_2.JPG \
0
# GPU inference
./infer_demo mobilefacenet_adaface/mobilefacenet_adaface.pdmodel \
mobilefacenet_adaface/mobilefacenet_adaface.pdiparams \
test_lite_focal_arcface_0.JPG \
test_lite_focal_arcface_1.JPG \
test_lite_focal_arcface_2.JPG \
1
# TensorRT inference on GPU
./infer_demo mobilefacenet_adaface/mobilefacenet_adaface.pdmodel \
mobilefacenet_adaface/mobilefacenet_adaface.pdiparams \
test_lite_focal_arcface_0.JPG \
test_lite_focal_arcface_1.JPG \
test_lite_focal_arcface_2.JPG \
2
```
After running, the visualized results are shown below:
<div width="700">
<img width="220" float="left" src="https://user-images.githubusercontent.com/67993288/184321537-860bf857-0101-4e92-a74c-48e8658d838c.JPG">
<img width="220" float="left" src="https://user-images.githubusercontent.com/67993288/184322004-a551e6e4-6f47-454e-95d6-f8ba2f47b516.JPG">
<img width="220" float="left" src="https://user-images.githubusercontent.com/67993288/184321622-d9a494c3-72f3-47f1-97c5-8a2372de491f.JPG">
</div>
The above commands apply only to Linux and macOS. For how to use the SDK on Windows, refer to:
- [Using the FastDeploy C++ SDK on Windows](../../../../../docs/compile/how_to_use_sdk_on_windows.md)
## AdaFace C++ Interface
### AdaFace Class
```c++
fastdeploy::vision::faceid::AdaFace(
       const std::string& model_file,
       const std::string& params_file = "",
       const RuntimeOption& runtime_option = RuntimeOption(),
       const ModelFormat& model_format = ModelFormat::PADDLE)
```
Loads and initializes an AdaFace model. With Paddle Inference, model_file and params_file are in the Paddle Inference model format;
with ONNX Runtime, model_file is an ONNX model and params_file is left empty.
#### Predict Function
> ```c++
> AdaFace::Predict(cv::Mat* im, FaceRecognitionResult* result)
> ```
>
> Model prediction interface: takes an image and returns the recognition result directly.
>
> **Parameters**
>
> > * **im**: input image; note it must be in HWC, BGR format
> > * **result**: face recognition result, i.e. the extracted feature embedding; see [Vision Model Prediction Results](../../../../../docs/api/vision_results/) for the FaceRecognitionResult description
### Class Member Variables
#### Preprocessing Parameters
Users can modify the preprocessing parameters below to suit their needs, which affects the final inference and deployment results
> > * **size**(vector&lt;int&gt;): the target size for resize during preprocessing, two integers meaning [width, height]; default [112, 112]
> > * **alpha**(vector&lt;float&gt;): normalization alpha in the formula `x'=x*alpha+beta`; default [1.f / 127.5f, 1.f / 127.5f, 1.f / 127.5f]
> > * **beta**(vector&lt;float&gt;): normalization beta in the formula `x'=x*alpha+beta`; default [-1.f, -1.f, -1.f]
> > * **swap_rb**(bool): whether to convert BGR to RGB during preprocessing; default true
> > * **l2_normalize**(bool): whether to L2-normalize the face embedding before output; default false
- [Model description](../../)
- [Python deployment](../python)
- [Vision model prediction results](../../../../../docs/api/vision_results/)
- [How to switch the inference backend](../../../../../docs/runtime/how_to_change_backend.md)

View File

@@ -0,0 +1,152 @@
/***************************************************************************
*
* Copyright (c) 2021 Baidu.com, Inc. All Rights Reserved
*
**************************************************************************/
/**
* @author Baidu
* @brief demo_image_inference
*
**/
#include "fastdeploy/vision.h"
void CpuInfer(const std::string &model_file, const std::string &params_file,
const std::vector<std::string> &image_file) {
auto model = fastdeploy::vision::faceid::AdaFace(model_file, params_file);
if (!model.Initialized()) {
std::cerr << "Failed to initialize." << std::endl;
return;
}
cv::Mat face0 = cv::imread(image_file[0]);
cv::Mat face1 = cv::imread(image_file[1]);
cv::Mat face2 = cv::imread(image_file[2]);
fastdeploy::vision::FaceRecognitionResult res0;
fastdeploy::vision::FaceRecognitionResult res1;
fastdeploy::vision::FaceRecognitionResult res2;
if ((!model.Predict(&face0, &res0)) || (!model.Predict(&face1, &res1)) ||
(!model.Predict(&face2, &res2))) {
std::cerr << "Prediction Failed." << std::endl;
}
std::cout << "Prediction Done!" << std::endl;
std::cout << "--- [Face 0]:" << res0.Str();
std::cout << "--- [Face 1]:" << res1.Str();
std::cout << "--- [Face 2]:" << res2.Str();
float cosine01 = fastdeploy::vision::utils::CosineSimilarity(
res0.embedding, res1.embedding, model.l2_normalize);
float cosine02 = fastdeploy::vision::utils::CosineSimilarity(
res0.embedding, res2.embedding, model.l2_normalize);
std::cout << "Detect Done! Cosine 01: " << cosine01
<< ", Cosine 02:" << cosine02 << std::endl;
}
void GpuInfer(const std::string &model_file, const std::string &params_file,
const std::vector<std::string> &image_file) {
auto option = fastdeploy::RuntimeOption();
option.UseGpu();
auto model =
fastdeploy::vision::faceid::AdaFace(model_file, params_file, option);
if (!model.Initialized()) {
std::cerr << "Failed to initialize." << std::endl;
return;
}
cv::Mat face0 = cv::imread(image_file[0]);
cv::Mat face1 = cv::imread(image_file[1]);
cv::Mat face2 = cv::imread(image_file[2]);
fastdeploy::vision::FaceRecognitionResult res0;
fastdeploy::vision::FaceRecognitionResult res1;
fastdeploy::vision::FaceRecognitionResult res2;
if ((!model.Predict(&face0, &res0)) || (!model.Predict(&face1, &res1)) ||
(!model.Predict(&face2, &res2))) {
std::cerr << "Prediction Failed." << std::endl;
}
std::cout << "Prediction Done!" << std::endl;
std::cout << "--- [Face 0]:" << res0.Str();
std::cout << "--- [Face 1]:" << res1.Str();
std::cout << "--- [Face 2]:" << res2.Str();
float cosine01 = fastdeploy::vision::utils::CosineSimilarity(
res0.embedding, res1.embedding, model.l2_normalize);
float cosine02 = fastdeploy::vision::utils::CosineSimilarity(
res0.embedding, res2.embedding, model.l2_normalize);
std::cout << "Detect Done! Cosine 01: " << cosine01
<< ", Cosine 02:" << cosine02 << std::endl;
}
void TrtInfer(const std::string &model_file, const std::string &params_file,
const std::vector<std::string> &image_file) {
auto option = fastdeploy::RuntimeOption();
option.UseGpu();
option.UseTrtBackend();
option.SetTrtInputShape("data", {1, 3, 112, 112});
auto model =
fastdeploy::vision::faceid::AdaFace(model_file, params_file, option);
if (!model.Initialized()) {
std::cerr << "Failed to initialize." << std::endl;
return;
}
cv::Mat face0 = cv::imread(image_file[0]);
cv::Mat face1 = cv::imread(image_file[1]);
cv::Mat face2 = cv::imread(image_file[2]);
fastdeploy::vision::FaceRecognitionResult res0;
fastdeploy::vision::FaceRecognitionResult res1;
fastdeploy::vision::FaceRecognitionResult res2;
if ((!model.Predict(&face0, &res0)) || (!model.Predict(&face1, &res1)) ||
(!model.Predict(&face2, &res2))) {
std::cerr << "Prediction Failed." << std::endl;
}
std::cout << "Prediction Done!" << std::endl;
std::cout << "--- [Face 0]:" << res0.Str();
std::cout << "--- [Face 1]:" << res1.Str();
std::cout << "--- [Face 2]:" << res2.Str();
float cosine01 = fastdeploy::vision::utils::CosineSimilarity(
res0.embedding, res1.embedding, model.l2_normalize);
float cosine02 = fastdeploy::vision::utils::CosineSimilarity(
res0.embedding, res2.embedding, model.l2_normalize);
std::cout << "Detect Done! Cosine 01: " << cosine01
<< ", Cosine 02:" << cosine02 << std::endl;
}
int main(int argc, char *argv[]) {
if (argc < 7) {
std::cout << "Usage: infer_demo path/to/model path/to/image run_option, "
"e.g ./infer_demo mobilefacenet_adaface.pdmodel "
"mobilefacenet_adaface.pdiparams "
"test_lite_focal_AdaFace_0.JPG test_lite_focal_AdaFace_1.JPG "
"test_lite_focal_AdaFace_2.JPG 0"
<< std::endl;
std::cout << "The data type of run_option is int, 0: run with cpu; 1: run "
"with gpu; 2: run with gpu and use tensorrt backend."
<< std::endl;
return -1;
}
std::vector<std::string> image_files = {argv[3], argv[4], argv[5]};
if (std::atoi(argv[6]) == 0) {
std::cout << "use CpuInfer" << std::endl;
CpuInfer(argv[1], argv[2], image_files);
} else if (std::atoi(argv[6]) == 1) {
GpuInfer(argv[1], argv[2], image_files);
} else if (std::atoi(argv[6]) == 2) {
TrtInfer(argv[1], argv[2], image_files);
}
return 0;
}

View File

@@ -0,0 +1,114 @@
# AdaFace Python Deployment Example
This directory provides `infer.py`, a quick example of deploying AdaFace on CPU/GPU, as well as on GPU with TensorRT acceleration.
Before deployment, confirm the following two steps:
- 1. The software and hardware environment meets the requirements, see [FastDeploy Environment Requirements](../../../../../docs/environment.md)
- 2. Install the FastDeploy Python whl package, see [FastDeploy Python Installation](../../../../../docs/quick_start)
Run the following script to complete the deployment:
```bash
# Download the deployment example code
git clone https://github.com/PaddlePaddle/FastDeploy.git
cd FastDeploy/examples/vision/faceid/adaface/python/
# Download the AdaFace model files and test images
wget https://bj.bcebos.com/paddlehub/test_samples/test_lite_focal_arcface_0.JPG
wget https://bj.bcebos.com/paddlehub/test_samples/test_lite_focal_arcface_1.JPG
wget https://bj.bcebos.com/paddlehub/test_samples/test_lite_focal_arcface_2.JPG
# If using a Paddle model, run the following
wget https://bj.bcebos.com/paddlehub/fastdeploy/mobilefacenet_adaface.tgz
tar zxvf mobilefacenet_adaface.tgz -C ./
# CPU inference
python infer.py --model mobilefacenet_adaface/mobilefacenet_adaface.pdmodel \
--params_file mobilefacenet_adaface/mobilefacenet_adaface.pdiparams \
--face test_lite_focal_arcface_0.JPG \
--face_positive test_lite_focal_arcface_1.JPG \
--face_negative test_lite_focal_arcface_2.JPG \
--device cpu
# GPU inference
python infer.py --model mobilefacenet_adaface/mobilefacenet_adaface.pdmodel \
--params_file mobilefacenet_adaface/mobilefacenet_adaface.pdiparams \
--face test_lite_focal_arcface_0.JPG \
--face_positive test_lite_focal_arcface_1.JPG \
--face_negative test_lite_focal_arcface_2.JPG \
--device gpu
# TensorRT inference on GPU
python infer.py --model mobilefacenet_adaface/mobilefacenet_adaface.pdmodel \
--params_file mobilefacenet_adaface/mobilefacenet_adaface.pdiparams \
--face test_lite_focal_arcface_0.JPG \
--face_positive test_lite_focal_arcface_1.JPG \
--face_negative test_lite_focal_arcface_2.JPG \
--device gpu \
--use_trt True
```
After running, the visualized results are shown below:
<div width="700">
<img width="220" float="left" src="https://user-images.githubusercontent.com/67993288/184321537-860bf857-0101-4e92-a74c-48e8658d838c.JPG">
<img width="220" float="left" src="https://user-images.githubusercontent.com/67993288/184322004-a551e6e4-6f47-454e-95d6-f8ba2f47b516.JPG">
<img width="220" float="left" src="https://user-images.githubusercontent.com/67993288/184321622-d9a494c3-72f3-47f1-97c5-8a2372de491f.JPG">
</div>
The console output is as follows:
```bash
FaceRecognitionResult: [Dim(512), Min(-0.133213), Max(0.148838), Mean(0.000293)]
FaceRecognitionResult: [Dim(512), Min(-0.102777), Max(0.120130), Mean(0.000615)]
FaceRecognitionResult: [Dim(512), Min(-0.116685), Max(0.142919), Mean(0.001595)]
Cosine 01: 0.7483505506964364
Cosine 02: -0.09605773855893639
```
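For reference, the similarity values above are plain cosine similarity between two embeddings $a$ and $b$; values close to 1 indicate the same identity (Cosine 01 compares two images of one person), while values near or below 0 indicate different identities (Cosine 02):

$$\cos(a, b) = \frac{a \cdot b}{\lVert a \rVert_2 \, \lVert b \rVert_2}$$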
## AdaFace Python Interface
```python
fastdeploy.vision.faceid.AdaFace(model_file, params_file=None, runtime_option=None, model_format=ModelFormat.PADDLE)
```
Loads and initializes an AdaFace model, where model_file is an exported ONNX model or a PADDLE static graph model
**Parameters**
> * **model_file**(str): path to the model file
> * **params_file**(str): path to the parameters file; not needed when the model is in ONNX format
> * **runtime_option**(RuntimeOption): backend inference configuration; None uses the default configuration
> * **model_format**(ModelFormat): model format; defaults to PADDLE
### predict Function
> ```python
> AdaFace.predict(image_data)
> ```
>
> Model prediction interface: takes an image and returns the recognition result directly.
>
> **Parameters**
>
> > * **image_data**(np.ndarray): input data; note it must be in HWC, BGR format
> **Returns**
>
> > a `fastdeploy.vision.FaceRecognitionResult` structure; see [Vision Model Prediction Results](../../../../../docs/api/vision_results/) for its description
### Class Member Properties
#### Preprocessing Parameters
Users can modify the preprocessing parameters below to suit their needs, which affects the final inference and deployment results (see the quick check after this list)
> > * **size**(list[int]): the target size for resize during preprocessing, two integers meaning [width, height]; default [112, 112]
> > * **alpha**(list[float]): normalization alpha in the formula `x'=x*alpha+beta`; default [1. / 127.5, 1. / 127.5, 1. / 127.5]
> > * **beta**(list[float]): normalization beta in the formula `x'=x*alpha+beta`; default [-1., -1., -1.]
> > * **swap_rb**(bool): whether to convert BGR to RGB during preprocessing; default True
> > * **l2_normalize**(bool): whether to L2-normalize the face embedding before output; default False
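As a quick sanity check of the default normalization (a standalone snippet, not part of the SDK), the default alpha and beta map 8-bit pixel values into [-1, 1]:

```python
import numpy as np

# x' = x * alpha + beta with alpha = 1/127.5 and beta = -1
alpha, beta = 1.0 / 127.5, -1.0
x = np.array([0.0, 127.5, 255.0])  # representative 8-bit pixel values
print(x * alpha + beta)  # [-1.  0.  1.]
```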
## Other Documents
- [AdaFace model description](..)
- [AdaFace C++ deployment](../cpp)
- [Model prediction result description](../../../../../docs/api/vision_results/)
- [How to switch the inference backend](../../../../../docs/runtime/how_to_change_backend.md)

View File

@@ -0,0 +1,92 @@
import fastdeploy as fd
import cv2
import numpy as np
# Cosine similarity
def cosine_similarity(a, b):
    a = np.array(a)
    b = np.array(b)
    mul_a = np.linalg.norm(a, ord=2)
    mul_b = np.linalg.norm(b, ord=2)
    mul_ab = np.dot(a, b)
    # norm(..., ord=2) already returns the L2 norm, so divide by the norms directly
    return mul_ab / (mul_a * mul_b)
def parse_arguments():
    import argparse
    import ast
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        required=True,
        help="Path of the AdaFace Paddle or ONNX model.")
    parser.add_argument(
        "--params_file",
        default=None,
        help="Path of the AdaFace Paddle model's params file.")
    parser.add_argument(
        "--face", required=True, help="Path of test face image file.")
    parser.add_argument(
        "--face_positive",
        required=True,
        help="Path of test face_positive image file.")
    parser.add_argument(
        "--face_negative",
        required=True,
        help="Path of test face_negative image file.")
    parser.add_argument(
        "--device",
        type=str,
        default='cpu',
        help="Type of inference device, support 'cpu' or 'gpu'.")
    parser.add_argument(
        "--use_trt",
        type=ast.literal_eval,
        default=False,
        help="Whether to use TensorRT.")
    return parser.parse_args()
def build_option(args):
    option = fd.RuntimeOption()
    if args.device.lower() == "gpu":
        option.use_gpu()
    if args.use_trt:
        option.use_trt_backend()
        option.set_trt_input_shape("data", [1, 3, 112, 112])
    return option
if __name__ == "__main__":
args = parse_arguments()
runtime_option = build_option(args)
model = fd.vision.faceid.AdaFace(
args.model, args.params_file, runtime_option=runtime_option)
face0 = cv2.imread(args.face)
face1 = cv2.imread(args.face_positive)
face2 = cv2.imread(args.face_negative)
model.l2_normalize = True
result0 = model.predict(face0)
result1 = model.predict(face1)
result2 = model.predict(face2)
embedding0 = result0.embedding
embedding1 = result1.embedding
embedding2 = result2.embedding
cosine01 = cosine_similarity(embedding0, embedding1)
cosine02 = cosine_similarity(embedding0, embedding2)
print(result0, end="")
print(result1, end="")
print(result2, end="")
print("Cosine 01: ", cosine01)
print("Cosine 02: ", cosine02)
print(model.runtime_option)

View File

@@ -32,7 +32,7 @@
 For developers' convenience, the InsightFace-exported models of each series are provided below for direct download. The accuracy figures come from the official repository, specifically the per-model descriptions in InsightFace; see the InsightFace documentation for details.
 | Model | Size | Accuracy (AgeDB_30) |
-|:---------------------------------------------------------------- |:----- |:----- |
+|:-------------------------------------------------------------------------------------------|:------|:--------------|
 | [CosFace-r18](https://bj.bcebos.com/paddlehub/fastdeploy/glint360k_cosface_r18.onnx) | 92MB | 97.7 |
 | [CosFace-r34](https://bj.bcebos.com/paddlehub/fastdeploy/glint360k_cosface_r34.onnx) | 131MB | 98.3 |
 | [CosFace-r50](https://bj.bcebos.com/paddlehub/fastdeploy/glint360k_cosface_r50.onnx) | 167MB | 98.3 |

View File

@@ -102,9 +102,7 @@ Loads and initializes a VPL model, where model_file is an exported ONNX model.
 #### Predict Function
 > ```c++
-> ArcFace::Predict(cv::Mat* im, FaceRecognitionResult* result,
->                  float conf_threshold = 0.25,
->                  float nms_iou_threshold = 0.5)
+> ArcFace::Predict(cv::Mat* im, FaceRecognitionResult* result)
 > ```
 >
 > Model prediction interface: takes an image and returns the recognition result directly.
@@ -113,8 +111,6 @@ Loads and initializes a VPL model, where model_file is an exported ONNX model.
 >
 > > * **im**: input image; note it must be in HWC, BGR format
 > > * **result**: detection result, including detection boxes and per-box confidences; see [Vision Model Prediction Results](../../../../../docs/api/vision_results/) for the FaceRecognitionResult description
-> > * **conf_threshold**: confidence threshold for filtering detection boxes
-> > * **nms_iou_threshold**: IoU threshold used during NMS
 ### Class Member Variables
 #### Preprocessing Parameters

View File

@@ -65,7 +65,7 @@ Loads and initializes an ArcFace model, where model_file is an exported ONNX model
 ### predict Function
 > ```python
-> ArcFace.predict(image_data, conf_threshold=0.25, nms_iou_threshold=0.5)
+> ArcFace.predict(image_data)
 > ```
 >
 > Model prediction interface: takes an image and returns the recognition result directly.
@@ -73,8 +73,6 @@ Loads and initializes an ArcFace model, where model_file is an exported ONNX model
 > **Parameters**
 >
 > > * **image_data**(np.ndarray): input data; note it must be in HWC, BGR format
-> > * **conf_threshold**(float): confidence threshold for filtering detection boxes
-> > * **nms_iou_threshold**(float): IoU threshold used during NMS
 > **Returns**
 >

View File

@@ -23,14 +23,15 @@
#include "fastdeploy/vision/detection/contrib/yolov5lite.h" #include "fastdeploy/vision/detection/contrib/yolov5lite.h"
#include "fastdeploy/vision/detection/contrib/yolov6.h" #include "fastdeploy/vision/detection/contrib/yolov6.h"
#include "fastdeploy/vision/detection/contrib/yolov7.h" #include "fastdeploy/vision/detection/contrib/yolov7.h"
#include "fastdeploy/vision/detection/contrib/yolov7end2end_trt.h"
#include "fastdeploy/vision/detection/contrib/yolov7end2end_ort.h" #include "fastdeploy/vision/detection/contrib/yolov7end2end_ort.h"
#include "fastdeploy/vision/detection/contrib/yolov7end2end_trt.h"
#include "fastdeploy/vision/detection/contrib/yolox.h" #include "fastdeploy/vision/detection/contrib/yolox.h"
#include "fastdeploy/vision/detection/ppdet/model.h" #include "fastdeploy/vision/detection/ppdet/model.h"
#include "fastdeploy/vision/facedet/contrib/retinaface.h" #include "fastdeploy/vision/facedet/contrib/retinaface.h"
#include "fastdeploy/vision/facedet/contrib/scrfd.h" #include "fastdeploy/vision/facedet/contrib/scrfd.h"
#include "fastdeploy/vision/facedet/contrib/ultraface.h" #include "fastdeploy/vision/facedet/contrib/ultraface.h"
#include "fastdeploy/vision/facedet/contrib/yolov5face.h" #include "fastdeploy/vision/facedet/contrib/yolov5face.h"
#include "fastdeploy/vision/faceid/contrib/adaface.h"
#include "fastdeploy/vision/faceid/contrib/arcface.h" #include "fastdeploy/vision/faceid/contrib/arcface.h"
#include "fastdeploy/vision/faceid/contrib/cosface.h" #include "fastdeploy/vision/faceid/contrib/cosface.h"
#include "fastdeploy/vision/faceid/contrib/insightface_rec.h" #include "fastdeploy/vision/faceid/contrib/insightface_rec.h"

View File

@@ -0,0 +1,74 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/faceid/contrib/adaface.h"
#include "fastdeploy/utils/perf.h"
#include "fastdeploy/vision/utils/utils.h"
namespace fastdeploy {
namespace vision {
namespace faceid {
AdaFace::AdaFace(const std::string& model_file, const std::string& params_file,
const RuntimeOption& custom_option,
const ModelFormat& model_format)
: InsightFaceRecognitionModel(model_file, params_file, custom_option,
model_format) {
initialized = Initialize();
}
bool AdaFace::Initialize() {
// (1) the parent class has already initialized the backend
if (initialized) {
// (1.1) re-init parameters for specific sub-classes
size = {112, 112};
alpha = {1.f / 127.5f, 1.f / 127.5f, 1.f / 127.5f};
beta = {-1.f, -1.f, -1.f}; // RGB
swap_rb = true;
l2_normalize = false;
return true;
}
// (2) the parent class has not initialized the backend yet
if (!InsightFaceRecognitionModel::Initialize()) {
FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
return false;
}
// (2.1) re-init parameters for specific sub-classes
size = {112, 112};
alpha = {1.f / 127.5f, 1.f / 127.5f, 1.f / 127.5f};
beta = {-1.f, -1.f, -1.f}; // RGB
swap_rb = true;
l2_normalize = false;
return true;
}
bool AdaFace::Preprocess(Mat* mat, FDTensor* output) {
return InsightFaceRecognitionModel::Preprocess(mat, output);
}
bool AdaFace::Postprocess(std::vector<FDTensor>& infer_result,
FaceRecognitionResult* result) {
return InsightFaceRecognitionModel::Postprocess(infer_result, result);
}
bool AdaFace::Predict(cv::Mat* im, FaceRecognitionResult* result) {
return InsightFaceRecognitionModel::Predict(im, result);
}
} // namespace faceid
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,50 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/fastdeploy_model.h"
#include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/result.h"
#include "fastdeploy/vision/faceid/contrib/insightface_rec.h"
namespace fastdeploy {
namespace vision {
namespace faceid {
class FASTDEPLOY_DECL AdaFace : public InsightFaceRecognitionModel {
public:
AdaFace(const std::string& model_file, const std::string& params_file = "",
const RuntimeOption& custom_option = RuntimeOption(),
const ModelFormat& model_format = ModelFormat::PADDLE);
std::string ModelName() const override {
return "Zheng-Bicheng/AdaFacePaddleCLas";
}
bool Predict(cv::Mat* im, FaceRecognitionResult* result) override;
private:
bool Initialize() override;
bool Preprocess(Mat* mat, FDTensor* output) override;
bool Postprocess(std::vector<FDTensor>& infer_result,
FaceRecognitionResult* result) override;
};
} // namespace faceid
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,38 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/pybind/main.h"
namespace fastdeploy {
void BindAdaFace(pybind11::module& m) {
// Bind AdaFace
pybind11::class_<vision::faceid::AdaFace,
vision::faceid::InsightFaceRecognitionModel>(m, "AdaFace")
.def(pybind11::init<std::string, std::string, RuntimeOption,
ModelFormat>())
.def("predict",
[](vision::faceid::AdaFace& self, pybind11::array& data) {
auto mat = PyArrayToCvMat(data);
vision::FaceRecognitionResult res;
self.Predict(&mat, &res);
return res;
})
.def_readwrite("size", &vision::faceid::AdaFace::size)
.def_readwrite("alpha", &vision::faceid::AdaFace::alpha)
.def_readwrite("beta", &vision::faceid::AdaFace::beta)
.def_readwrite("swap_rb", &vision::faceid::AdaFace::swap_rb)
.def_readwrite("l2_normalize", &vision::faceid::AdaFace::l2_normalize);
}
} // namespace fastdeploy

View File

@@ -13,6 +13,7 @@
 // limitations under the License.
 #include "fastdeploy/vision/faceid/contrib/insightface_rec.h"
 #include "fastdeploy/utils/perf.h"
 #include "fastdeploy/vision/utils/utils.h"

View File

@@ -15,7 +15,7 @@
#include "fastdeploy/pybind/main.h" #include "fastdeploy/pybind/main.h"
namespace fastdeploy { namespace fastdeploy {
void BindAdaFace(pybind11::module& m);
void BindArcFace(pybind11::module& m); void BindArcFace(pybind11::module& m);
void BindInsightFaceRecognitionModel(pybind11::module& m); void BindInsightFaceRecognitionModel(pybind11::module& m);
void BindCosFace(pybind11::module& m); void BindCosFace(pybind11::module& m);
@@ -25,6 +25,7 @@ void BindVPL(pybind11::module& m);
void BindFaceId(pybind11::module& m) { void BindFaceId(pybind11::module& m) {
auto faceid_module = m.def_submodule("faceid", "Face recognition models."); auto faceid_module = m.def_submodule("faceid", "Face recognition models.");
BindInsightFaceRecognitionModel(faceid_module); BindInsightFaceRecognitionModel(faceid_module);
BindAdaFace(faceid_module);
BindArcFace(faceid_module); BindArcFace(faceid_module);
BindCosFace(faceid_module); BindCosFace(faceid_module);
BindPartialFC(faceid_module); BindPartialFC(faceid_module);

View File

@@ -13,6 +13,7 @@
 # limitations under the License.
 from __future__ import absolute_import
+from .contrib.adaface import AdaFace
 from .contrib.arcface import ArcFace
 from .contrib.cosface import CosFace
 from .contrib.insightface_rec import InsightFaceRecognitionModel

View File

@@ -0,0 +1,98 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from .... import FastDeployModel, ModelFormat
from .... import c_lib_wrap as C
class AdaFace(FastDeployModel):
    def __init__(self,
                 model_file,
                 params_file="",
                 runtime_option=None,
                 model_format=ModelFormat.PADDLE):
        # Call the base class constructor to initialize backend options;
        # the initialized option is stored in self._runtime_option
        super(AdaFace, self).__init__(runtime_option)
        self._model = C.vision.faceid.AdaFace(
            model_file, params_file, self._runtime_option, model_format)
        # self.initialized indicates whether the model initialized successfully
        assert self.initialized, "AdaFace initialize failed."

    def predict(self, input_image):
        return self._model.predict(input_image)

    # Property wrappers for model-related attributes, mostly preprocessing;
    # e.g. model.size = [112, 112] changes the resize target during
    # preprocessing (provided the model supports it)
    @property
    def size(self):
        return self._model.size

    @property
    def alpha(self):
        return self._model.alpha

    @property
    def beta(self):
        return self._model.beta

    @property
    def swap_rb(self):
        return self._model.swap_rb

    @property
    def l2_normalize(self):
        return self._model.l2_normalize

    @size.setter
    def size(self, wh):
        assert isinstance(wh, (list, tuple)), \
            "The value to set `size` must be type of tuple or list."
        assert len(wh) == 2, \
            "The value to set `size` must contain 2 elements meaning [width, height], but now it contains {} elements.".format(
                len(wh))
        self._model.size = wh
    @alpha.setter
    def alpha(self, value):
        assert isinstance(value, (list, tuple)), \
            "The value to set `alpha` must be type of tuple or list."
        assert len(value) == 3, \
            "The value to set `alpha` must contain 3 elements, one for each channel, but now it contains {} elements.".format(
                len(value))
        self._model.alpha = value

    @beta.setter
    def beta(self, value):
        assert isinstance(value, (list, tuple)), \
            "The value to set `beta` must be type of tuple or list."
        assert len(value) == 3, \
            "The value to set `beta` must contain 3 elements, one for each channel, but now it contains {} elements.".format(
                len(value))
        self._model.beta = value

    @swap_rb.setter
    def swap_rb(self, value):
        assert isinstance(
            value, bool), "The value to set `swap_rb` must be type of bool."
        self._model.swap_rb = value

    @l2_normalize.setter
    def l2_normalize(self, value):
        assert isinstance(
            value,
            bool), "The value to set `l2_normalize` must be type of bool."
        self._model.l2_normalize = value
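A minimal usage sketch of the wrapper above, assuming the model and test image downloaded in the example docs (paths are placeholders otherwise):

```python
import cv2
import fastdeploy as fd

model = fd.vision.faceid.AdaFace(
    "mobilefacenet_adaface/mobilefacenet_adaface.pdmodel",
    "mobilefacenet_adaface/mobilefacenet_adaface.pdiparams")
model.l2_normalize = True  # emit unit-length embeddings
result = model.predict(cv2.imread("test_lite_focal_arcface_0.JPG"))
print(len(result.embedding))  # 512-dimensional feature vector
```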