Merge branch 'develop' into new_ppdet1

This commit is contained in:
Jason
2022-08-05 09:27:16 +08:00
committed by GitHub
31 changed files with 2194 additions and 7 deletions

View File

@@ -17,7 +17,12 @@
#ifdef ENABLE_VISION
#include "fastdeploy/vision/biubug6/retinaface.h"
#include "fastdeploy/vision/deepcam/yolov5face.h"
#include "fastdeploy/vision/deepinsight/arcface.h"
#include "fastdeploy/vision/deepinsight/cosface.h"
#include "fastdeploy/vision/deepinsight/insightface_rec.h"
#include "fastdeploy/vision/deepinsight/partial_fc.h"
#include "fastdeploy/vision/deepinsight/scrfd.h"
#include "fastdeploy/vision/deepinsight/vpl.h"
#include "fastdeploy/vision/linzaer/ultraface.h"
#include "fastdeploy/vision/megvii/yolox.h"
#include "fastdeploy/vision/meituan/yolov6.h"

View File

@@ -187,5 +187,43 @@ std::string SegmentationResult::Str() {
return out;
}
// Copy constructor: the new result receives a copy of the source result's
// embedding values.
FaceRecognitionResult::FaceRecognitionResult(const FaceRecognitionResult& res) {
  // Vector copy-assignment produces the same elements as assign(begin, end).
  embedding = res.embedding;
}
// Removes all embedding values. Swapping with an empty vector (instead of
// calling clear()) also hands the old storage to the temporary, which frees it
// when it goes out of scope.
void FaceRecognitionResult::Clear() {
  std::vector<float> empty;
  embedding.swap(empty);
}
// Requests capacity for `size` embedding values; the number of stored elements
// is not changed.
void FaceRecognitionResult::Reserve(int size) {
  embedding.reserve(size);
}
// Changes the number of stored embedding values to `size`.
void FaceRecognitionResult::Resize(int size) {
  embedding.resize(size);
}
// Builds a short textual summary of the embedding: its dimension together with
// the minimum, maximum and mean of the stored values. When the embedding is
// empty the summary only states "Empty Result".
std::string FaceRecognitionResult::Str() {
  const std::string prefix = "FaceRecognitionResult: [";
  const size_t count = embedding.size();
  if (count == 0) {
    return prefix + "Empty Result]";
  }

  // Seed the running statistics with the first value, then scan the rest.
  float lowest = embedding.at(0);
  float highest = lowest;
  float sum = lowest;
  for (size_t idx = 1; idx < count; ++idx) {
    const float current = embedding.at(idx);
    sum += current;
    if (current < lowest) {
      lowest = current;
    }
    if (current > highest) {
      highest = current;
    }
  }
  const float average = sum / static_cast<float>(count);

  std::string out = prefix;
  out += "Dim(" + std::to_string(count) + "), ";
  out += "Min(" + std::to_string(lowest) + "), ";
  out += "Max(" + std::to_string(highest) + "), ";
  out += "Mean(" + std::to_string(average) + ")]\n";
  return out;
}
} // namespace vision
} // namespace fastdeploy

View File

@@ -22,7 +22,8 @@ enum FASTDEPLOY_DECL ResultType {
CLASSIFY,
DETECTION,
SEGMENTATION,
FACE_DETECTION
FACE_DETECTION,
FACE_RECOGNITION
};
struct FASTDEPLOY_DECL BaseResult {
@@ -100,5 +101,23 @@ struct FASTDEPLOY_DECL SegmentationResult : public BaseResult {
std::string Str();
};
// Result of a face recognition model: a single embedding (feature) vector.
struct FASTDEPLOY_DECL FaceRecognitionResult : public BaseResult {
// face embedding vector with 128/256/512 ... dim
std::vector<float> embedding;
// Tag identifying this result as a face recognition result.
ResultType type = ResultType::FACE_RECOGNITION;
// Creates a result with an empty embedding.
FaceRecognitionResult() {}
// Copies the embedding from `res`.
FaceRecognitionResult(const FaceRecognitionResult& res);
// Removes all embedding values.
void Clear();
// Reserves capacity for `size` embedding values.
void Reserve(int size);
// Resizes the embedding to hold `size` values.
void Resize(int size);
// Returns a textual summary of the embedding (dimension, min, max, mean).
std::string Str();
};
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,83 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/deepinsight/arcface.h"
#include "fastdeploy/utils/perf.h"
#include "fastdeploy/vision/utils/utils.h"
namespace fastdeploy {
namespace vision {
namespace deepinsight {
// Constructs an ArcFace model. All arguments are forwarded to the
// InsightFaceRecognitionModel constructor; afterwards ArcFace::Initialize()
// runs and its outcome is stored in `initialized`.
ArcFace::ArcFace(const std::string& model_file, const std::string& params_file,
                 const RuntimeOption& custom_option,
                 const Frontend& model_format)
    : InsightFaceRecognitionModel(model_file, params_file, custom_option,
                                  model_format) {
  const bool init_ok = Initialize();
  initialized = init_ok;
}
// Applies the ArcFace-specific preprocessing parameters. Modify this function
// if ArcFace ever needs a different initialization.
//
// The backend is normally initialized by the parent class constructor. In that
// case `initialized` is already true and the parent's Initialize() must not be
// called again, because it would initialize the backend a second time. Only
// when the parent has not initialized the backend is
// InsightFaceRecognitionModel::Initialize() invoked here.
bool ArcFace::Initialize() {
  const bool backend_ready = initialized;
  if (!backend_ready && !InsightFaceRecognitionModel::Initialize()) {
    FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
    return false;
  }
  // Re-init the parameters for this specific sub-class.
  size = {112, 112};
  alpha = {1.f / 127.5f, 1.f / 127.5f, 1.f / 127.5f};
  beta = {-1.f, -1.f, -1.f};  // RGB
  swap_rb = true;
  l2_normalize = false;
  return true;
}
// Preprocessing hook. ArcFace currently delegates to the parent class; modify
// this function if the preprocessing ever needs to differ.
bool ArcFace::Preprocess(Mat* mat, FDTensor* output) {
  const bool ok = InsightFaceRecognitionModel::Preprocess(mat, output);
  return ok;
}
// Postprocessing hook. ArcFace currently delegates to the parent class; modify
// this function if the postprocessing ever needs to differ.
bool ArcFace::Postprocess(std::vector<FDTensor>& infer_result,
                          FaceRecognitionResult* result) {
  const bool ok =
      InsightFaceRecognitionModel::Postprocess(infer_result, result);
  return ok;
}
// User-facing prediction entry point; delegates to the parent class.
// If the pre/post-processing changes, override Preprocess and Postprocess in
// this sub-class; this function should then call the sub-class's own versions.
bool ArcFace::Predict(cv::Mat* im, FaceRecognitionResult* result) {
  const bool ok = InsightFaceRecognitionModel::Predict(im, result);
  return ok;
}
} // namespace deepinsight
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,65 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/fastdeploy_model.h"
#include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/result.h"
#include "fastdeploy/vision/deepinsight/insightface_rec.h"
namespace fastdeploy {
namespace vision {
namespace deepinsight {
// ArcFace face recognition model built on InsightFaceRecognitionModel.
class FASTDEPLOY_DECL ArcFace : public InsightFaceRecognitionModel {
public:
// When model_format is ONNX, params_file does not need to be specified.
// When model_format is Paddle, both model_file and params_file are required.
// ArcFace supports the IResNet, IResNet2060, VIT and MobileFaceNet backbones.
ArcFace(const std::string& model_file, const std::string& params_file = "",
const RuntimeOption& custom_option = RuntimeOption(),
const Frontend& model_format = Frontend::ONNX);
// Name of the model.
std::string ModelName() const override {
return "deepinsight/insightface/recognition/arcface_pytorch";
}
// Prediction interface, i.e. the entry point called by users.
// im is the user's input data; for CV tasks it is currently always a cv::Mat.
// result is the output structure of the model prediction.
bool Predict(cv::Mat* im, FaceRecognitionResult* result) override;
// The parent class holds the basic configurable attributes:
// size, alpha, beta, swap_rb, l2_normalize.
private:
// Initialization: sets up the backend and anything else needed for
// model inference.
bool Initialize() override;
// Preprocessing of the input image.
// Mat is the data structure defined by FastDeploy.
// FDTensor is the preprocessed tensor handed to the backend for inference.
bool Preprocess(Mat* mat, FDTensor* output) override;
// Postprocessing of the backend inference result, producing the user output.
// infer_result is the output tensor(s) of the backend inference.
// result is the model's prediction result.
bool Postprocess(std::vector<FDTensor>& infer_result,
FaceRecognitionResult* result) override;
};
} // namespace deepinsight
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,83 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/deepinsight/cosface.h"
#include "fastdeploy/utils/perf.h"
#include "fastdeploy/vision/utils/utils.h"
namespace fastdeploy {
namespace vision {
namespace deepinsight {
// Constructs a CosFace model. All arguments are forwarded to the
// InsightFaceRecognitionModel constructor; afterwards CosFace::Initialize()
// runs and its outcome is stored in `initialized`.
CosFace::CosFace(const std::string& model_file, const std::string& params_file,
                 const RuntimeOption& custom_option,
                 const Frontend& model_format)
    : InsightFaceRecognitionModel(model_file, params_file, custom_option,
                                  model_format) {
  const bool init_ok = Initialize();
  initialized = init_ok;
}
// Applies the CosFace-specific preprocessing parameters. Modify this function
// if CosFace ever needs a different initialization.
//
// The backend is normally initialized by the parent class constructor. In that
// case `initialized` is already true and the parent's Initialize() must not be
// called again, because it would initialize the backend a second time. Only
// when the parent has not initialized the backend is
// InsightFaceRecognitionModel::Initialize() invoked here.
bool CosFace::Initialize() {
  const bool backend_ready = initialized;
  if (!backend_ready && !InsightFaceRecognitionModel::Initialize()) {
    FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
    return false;
  }
  // Re-init the parameters for this specific sub-class.
  size = {112, 112};
  alpha = {1.f / 127.5f, 1.f / 127.5f, 1.f / 127.5f};
  beta = {-1.f, -1.f, -1.f};  // RGB
  swap_rb = true;
  l2_normalize = false;
  return true;
}
// Preprocessing hook. CosFace currently delegates to the parent class; modify
// this function if the preprocessing ever needs to differ.
bool CosFace::Preprocess(Mat* mat, FDTensor* output) {
  const bool ok = InsightFaceRecognitionModel::Preprocess(mat, output);
  return ok;
}
// Postprocessing hook. CosFace currently delegates to the parent class; modify
// this function if the postprocessing ever needs to differ.
bool CosFace::Postprocess(std::vector<FDTensor>& infer_result,
                          FaceRecognitionResult* result) {
  const bool ok =
      InsightFaceRecognitionModel::Postprocess(infer_result, result);
  return ok;
}
// User-facing prediction entry point; delegates to the parent class.
// If the pre/post-processing changes, override Preprocess and Postprocess in
// this sub-class; this function should then call the sub-class's own versions.
bool CosFace::Predict(cv::Mat* im, FaceRecognitionResult* result) {
  const bool ok = InsightFaceRecognitionModel::Predict(im, result);
  return ok;
}
} // namespace deepinsight
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,66 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/fastdeploy_model.h"
#include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/result.h"
#include "fastdeploy/vision/deepinsight/insightface_rec.h"
namespace fastdeploy {
namespace vision {
namespace deepinsight {
// CosFace face recognition model built on InsightFaceRecognitionModel.
class FASTDEPLOY_DECL CosFace : public InsightFaceRecognitionModel {
public:
// When model_format is ONNX, params_file does not need to be specified.
// When model_format is Paddle, both model_file and params_file are required.
// ArcFace supports the IResNet, IResNet2060, VIT and MobileFaceNet backbones.
// NOTE(review): the original comment names ArcFace here; presumably the same
// backbone list is meant for CosFace -- confirm.
CosFace(const std::string& model_file, const std::string& params_file = "",
const RuntimeOption& custom_option = RuntimeOption(),
const Frontend& model_format = Frontend::ONNX);
// Name of the model.
// The model files provided by insightface/arcface include cosface.
std::string ModelName() const override {
return "deepinsight/insightface/recognition/arcface_pytorch";
}
// Prediction interface, i.e. the entry point called by users.
// im is the user's input data; for CV tasks it is currently always a cv::Mat.
// result is the output structure of the model prediction.
bool Predict(cv::Mat* im, FaceRecognitionResult* result) override;
// The parent class holds the basic configurable attributes:
// size, alpha, beta, swap_rb, l2_normalize.
private:
// Initialization: sets up the backend and anything else needed for
// model inference.
bool Initialize() override;
// Preprocessing of the input image.
// Mat is the data structure defined by FastDeploy.
// FDTensor is the preprocessed tensor handed to the backend for inference.
bool Preprocess(Mat* mat, FDTensor* output) override;
// Postprocessing of the backend inference result, producing the user output.
// infer_result is the output tensor(s) of the backend inference.
// result is the model's prediction result.
bool Postprocess(std::vector<FDTensor>& infer_result,
FaceRecognitionResult* result) override;
};
} // namespace deepinsight
} // namespace vision
} // namespace fastdeploy

View File

@@ -15,9 +15,10 @@
#include "fastdeploy/pybind/main.h"
namespace fastdeploy {
void BindDeepinsight(pybind11::module& m) {
auto deepinsight_module = m.def_submodule(
"deepinsight", "https://github.com/deepinsight");
void BindDeepInsight(pybind11::module& m) {
auto deepinsight_module =
m.def_submodule("deepinsight", "https://github.com/deepinsight");
// Bind SCRFD
pybind11::class_<vision::deepinsight::SCRFD, FastDeployModel>(
deepinsight_module, "SCRFD")
.def(pybind11::init<std::string, std::string, RuntimeOption, Frontend>())
@@ -43,5 +44,101 @@ void BindDeepinsight(pybind11::module& m) {
.def_readwrite("num_anchors", &vision::deepinsight::SCRFD::num_anchors)
.def_readwrite("landmarks_per_face",
&vision::deepinsight::SCRFD::landmarks_per_face);
// Bind InsightFaceRecognitionModel: registers the class on deepinsight_module
// with its constructor, a `predict` method and the read/write attributes
// size, alpha, beta, swap_rb and l2_normalize.
pybind11::class_<vision::deepinsight::InsightFaceRecognitionModel,
FastDeployModel>(deepinsight_module,
"InsightFaceRecognitionModel")
.def(pybind11::init<std::string, std::string, RuntimeOption, Frontend>())
.def("predict",
[](vision::deepinsight::InsightFaceRecognitionModel& self,
pybind11::array& data) {
// Convert the incoming array to a cv::Mat and run the prediction.
auto mat = PyArrayToCvMat(data);
vision::FaceRecognitionResult res;
// NOTE(review): the bool returned by Predict() is discarded, so a
// failed prediction still returns `res` -- confirm this is intended.
self.Predict(&mat, &res);
return res;
})
.def_readwrite("size",
&vision::deepinsight::InsightFaceRecognitionModel::size)
.def_readwrite("alpha",
&vision::deepinsight::InsightFaceRecognitionModel::alpha)
.def_readwrite("beta",
&vision::deepinsight::InsightFaceRecognitionModel::beta)
.def_readwrite("swap_rb",
&vision::deepinsight::InsightFaceRecognitionModel::swap_rb)
.def_readwrite(
"l2_normalize",
&vision::deepinsight::InsightFaceRecognitionModel::l2_normalize);
// Bind ArcFace: registered on deepinsight_module with
// InsightFaceRecognitionModel as its bound parent class; exposes the
// constructor, `predict` and the read/write preprocessing attributes.
pybind11::class_<vision::deepinsight::ArcFace,
vision::deepinsight::InsightFaceRecognitionModel>(
deepinsight_module, "ArcFace")
.def(pybind11::init<std::string, std::string, RuntimeOption, Frontend>())
.def("predict",
[](vision::deepinsight::ArcFace& self, pybind11::array& data) {
// Convert the incoming array to a cv::Mat and run the prediction.
auto mat = PyArrayToCvMat(data);
vision::FaceRecognitionResult res;
// NOTE(review): the bool returned by Predict() is discarded, so a
// failed prediction still returns `res` -- confirm this is intended.
self.Predict(&mat, &res);
return res;
})
.def_readwrite("size", &vision::deepinsight::ArcFace::size)
.def_readwrite("alpha", &vision::deepinsight::ArcFace::alpha)
.def_readwrite("beta", &vision::deepinsight::ArcFace::beta)
.def_readwrite("swap_rb", &vision::deepinsight::ArcFace::swap_rb)
.def_readwrite("l2_normalize",
&vision::deepinsight::ArcFace::l2_normalize);
// Bind CosFace: registered on deepinsight_module with
// InsightFaceRecognitionModel as its bound parent class; exposes the
// constructor, `predict` and the read/write preprocessing attributes.
pybind11::class_<vision::deepinsight::CosFace,
vision::deepinsight::InsightFaceRecognitionModel>(
deepinsight_module, "CosFace")
.def(pybind11::init<std::string, std::string, RuntimeOption, Frontend>())
.def("predict",
[](vision::deepinsight::CosFace& self, pybind11::array& data) {
// Convert the incoming array to a cv::Mat and run the prediction.
auto mat = PyArrayToCvMat(data);
vision::FaceRecognitionResult res;
// NOTE(review): the bool returned by Predict() is discarded, so a
// failed prediction still returns `res` -- confirm this is intended.
self.Predict(&mat, &res);
return res;
})
.def_readwrite("size", &vision::deepinsight::CosFace::size)
.def_readwrite("alpha", &vision::deepinsight::CosFace::alpha)
.def_readwrite("beta", &vision::deepinsight::CosFace::beta)
.def_readwrite("swap_rb", &vision::deepinsight::CosFace::swap_rb)
.def_readwrite("l2_normalize",
&vision::deepinsight::CosFace::l2_normalize);
// Bind Partial FC: registered on deepinsight_module with
// InsightFaceRecognitionModel as its bound parent class; exposes the
// constructor, `predict` and the read/write preprocessing attributes.
pybind11::class_<vision::deepinsight::PartialFC,
vision::deepinsight::InsightFaceRecognitionModel>(
deepinsight_module, "PartialFC")
.def(pybind11::init<std::string, std::string, RuntimeOption, Frontend>())
.def("predict",
[](vision::deepinsight::PartialFC& self, pybind11::array& data) {
// Convert the incoming array to a cv::Mat and run the prediction.
auto mat = PyArrayToCvMat(data);
vision::FaceRecognitionResult res;
// NOTE(review): the bool returned by Predict() is discarded, so a
// failed prediction still returns `res` -- confirm this is intended.
self.Predict(&mat, &res);
return res;
})
.def_readwrite("size", &vision::deepinsight::PartialFC::size)
.def_readwrite("alpha", &vision::deepinsight::PartialFC::alpha)
.def_readwrite("beta", &vision::deepinsight::PartialFC::beta)
.def_readwrite("swap_rb", &vision::deepinsight::PartialFC::swap_rb)
.def_readwrite("l2_normalize",
&vision::deepinsight::PartialFC::l2_normalize);
// Bind VPL: registered on deepinsight_module with
// InsightFaceRecognitionModel as its bound parent class; exposes the
// constructor, `predict` and the read/write preprocessing attributes.
pybind11::class_<vision::deepinsight::VPL,
vision::deepinsight::InsightFaceRecognitionModel>(
deepinsight_module, "VPL")
.def(pybind11::init<std::string, std::string, RuntimeOption, Frontend>())
.def("predict",
[](vision::deepinsight::VPL& self, pybind11::array& data) {
// Convert the incoming array to a cv::Mat and run the prediction.
auto mat = PyArrayToCvMat(data);
vision::FaceRecognitionResult res;
// NOTE(review): the bool returned by Predict() is discarded, so a
// failed prediction still returns `res` -- confirm this is intended.
self.Predict(&mat, &res);
return res;
})
.def_readwrite("size", &vision::deepinsight::VPL::size)
.def_readwrite("alpha", &vision::deepinsight::VPL::alpha)
.def_readwrite("beta", &vision::deepinsight::VPL::beta)
.def_readwrite("swap_rb", &vision::deepinsight::VPL::swap_rb)
.def_readwrite("l2_normalize", &vision::deepinsight::VPL::l2_normalize);
}
} // namespace fastdeploy

View File

@@ -0,0 +1,153 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/deepinsight/insightface_rec.h"
#include "fastdeploy/utils/perf.h"
#include "fastdeploy/vision/utils/utils.h"
namespace fastdeploy {
namespace vision {
namespace deepinsight {
// Constructs the base insightface recognition model: selects the valid
// CPU/GPU backends for the given model format, records the runtime options
// and runs Initialize(), storing its outcome in `initialized`.
InsightFaceRecognitionModel::InsightFaceRecognitionModel(
    const std::string& model_file, const std::string& params_file,
    const RuntimeOption& custom_option, const Frontend& model_format) {
  const bool is_onnx = (model_format == Frontend::ONNX);
  if (!is_onnx) {
    valid_cpu_backends = {Backend::PDINFER, Backend::ORT};
    valid_gpu_backends = {Backend::PDINFER, Backend::ORT, Backend::TRT};
  } else {
    valid_cpu_backends = {Backend::ORT};                // usable CPU backends
    valid_gpu_backends = {Backend::ORT, Backend::TRT};  // usable GPU backends
  }

  // Start from the caller's options, then overwrite the model description.
  runtime_option = custom_option;
  runtime_option.model_file = model_file;
  runtime_option.params_file = params_file;
  runtime_option.model_format = model_format;

  // Inside a constructor this call resolves to this class's Initialize(),
  // not to a sub-class override.
  initialized = Initialize();
}
// Sets the default preprocessing parameters and initializes the runtime.
// Returns false (after logging an error) when InitRuntime() fails.
bool InsightFaceRecognitionModel::Initialize() {
  // Parameters for preprocess.
  size = {112, 112};
  alpha = {1.f / 127.5f, 1.f / 127.5f, 1.f / 127.5f};
  beta = {-1.f, -1.f, -1.f};  // RGB
  swap_rb = true;
  l2_normalize = false;

  const bool runtime_ok = InitRuntime();
  if (runtime_ok) {
    return true;
  }
  FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
  return false;
}
// Preprocesses an input image for the insightface recognition models and
// shares the processed data with `output`.
//
// Steps (reference: insightface/recognition/arcface_torch/inference.py):
//   1. Resize to `size` (width, height) when the image size differs
//   2. BGR2RGB when `swap_rb` is set
//   3. Convert (opencv style): x' = x * alpha + beta
//   4. HWC2CHW
//   5. Cast to float
//
// mat:    image to process; it is modified in place.
// output: tensor that receives the processed data with a leading batch
//         dimension of 1.
// Returns false (after logging an error) when `size` does not hold exactly
// two values; returns true otherwise.
bool InsightFaceRecognitionModel::Preprocess(Mat* mat, FDTensor* output) {
  // `size` is a user-writable attribute, so validate it before indexing:
  // reading size[0] / size[1] on a shorter vector is undefined behavior.
  if (size.size() != 2) {
    FDERROR << "The size attribute must contain exactly 2 values: "
               "(width, height)."
            << std::endl;
    return false;
  }
  int resize_w = size[0];
  int resize_h = size[1];
  if (resize_h != mat->Height() || resize_w != mat->Width()) {
    Resize::Run(mat, resize_w, resize_h);
  }
  if (swap_rb) {
    BGR2RGB::Run(mat);
  }
  Convert::Run(mat, alpha, beta);
  HWC2CHW::Run(mat);
  Cast::Run(mat, "float");
  mat->ShareWithTensor(output);
  // Prepend a batch dimension of 1. HWC2CHW has already run, so the layout is
  // presumably n, c, h, w (the previous comment said n, h, w, c).
  output->shape.insert(output->shape.begin(), 1);
  return true;
}
// Copies the single output tensor of the backend into `result->embedding`.
//
// infer_result: backend outputs; exactly one tensor with a leading dimension
//               of 1 is expected (enforced with FDASSERT).
// result:       receives the embedding values.
// Returns false when the tensor is not float32, true otherwise.
bool InsightFaceRecognitionModel::Postprocess(
    std::vector<FDTensor>& infer_result, FaceRecognitionResult* result) {
  FDASSERT((infer_result.size() == 1),
           "The default number of output tensor must be 1 according to "
           "insightface.");
  FDTensor& embedding_tensor = infer_result.at(0);
  FDASSERT((embedding_tensor.shape[0] == 1), "Only support batch =1 now.");
  if (embedding_tensor.dtype != FDDataType::FP32) {
    FDERROR << "Only support post process with float32 data." << std::endl;
    return false;
  }

  // Size the destination to the tensor's element count, then copy the raw
  // embedding values as they are.
  result->Clear();
  result->Resize(embedding_tensor.Numel());
  std::memcpy(result->embedding.data(), embedding_tensor.Data(),
              embedding_tensor.Nbytes());

  // The raw embedding is kept unless the user asked for L2 normalization;
  // whether to normalize is the user's decision.
  if (!l2_normalize) {
    return true;
  }
  auto normalized = utils::L2Normalize(result->embedding);
  std::memcpy(result->embedding.data(), normalized.data(),
              embedding_tensor.Nbytes());
  return true;
}
// Runs the full pipeline on one image: Preprocess -> Infer -> Postprocess.
// Each stage logs an error and makes this function return false on failure.
// When FASTDEPLOY_DEBUG is defined, each stage is wrapped in TIMERECORD
// macros.
//
// im:     input image.
// result: receives the prediction result.
bool InsightFaceRecognitionModel::Predict(cv::Mat* im,
                                          FaceRecognitionResult* result) {
#ifdef FASTDEPLOY_DEBUG
  TIMERECORD_START(0)
#endif

  // Stage 1: preprocess into the single input tensor.
  Mat wrapped(*im);
  std::vector<FDTensor> inputs(1);
  const bool preprocessed = Preprocess(&wrapped, &inputs[0]);
  if (!preprocessed) {
    FDERROR << "Failed to preprocess input image." << std::endl;
    return false;
  }

#ifdef FASTDEPLOY_DEBUG
  TIMERECORD_END(0, "Preprocess")
  TIMERECORD_START(1)
#endif

  // Stage 2: name the input after the runtime's first input and infer.
  inputs[0].name = InputInfoOfRuntime(0).name;
  std::vector<FDTensor> outputs;
  const bool inferred = Infer(inputs, &outputs);
  if (!inferred) {
    FDERROR << "Failed to inference." << std::endl;
    return false;
  }

#ifdef FASTDEPLOY_DEBUG
  TIMERECORD_END(1, "Inference")
  TIMERECORD_START(2)
#endif

  // Stage 3: turn the backend output into the user-facing result.
  const bool postprocessed = Postprocess(outputs, result);
  if (!postprocessed) {
    FDERROR << "Failed to post process." << std::endl;
    return false;
  }

#ifdef FASTDEPLOY_DEBUG
  TIMERECORD_END(2, "Postprocess")
#endif
  return true;
}
} // namespace deepinsight
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,72 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/fastdeploy_model.h"
#include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/result.h"
namespace fastdeploy {
namespace vision {
namespace deepinsight {
// Base class for the insightface/recognition face recognition models.
class FASTDEPLOY_DECL InsightFaceRecognitionModel : public FastDeployModel {
public:
// When model_format is ONNX, params_file does not need to be specified.
// When model_format is Paddle, both model_file and params_file are required.
// Base class supporting the insightface/recognition face recognition models.
InsightFaceRecognitionModel(
const std::string& model_file, const std::string& params_file = "",
const RuntimeOption& custom_option = RuntimeOption(),
const Frontend& model_format = Frontend::ONNX);
// Name of the model.
virtual std::string ModelName() const { return "deepinsight/insightface"; }
// The following attributes may be modified by users.
// tuple of (width, height), default (112, 112)
std::vector<int> size;
// alpha and beta used for normalization: x' = x * alpha + beta
std::vector<float> alpha;
std::vector<float> beta;
// whether to swap the B and R channel, such as BGR->RGB, default true.
bool swap_rb;
// whether to apply l2 normalize to embedding values, default false
// (set in Initialize()).
bool l2_normalize;
// Prediction interface, i.e. the entry point called by users.
// im is the user's input data; for CV tasks it is currently always a cv::Mat.
// result is the output structure of the model prediction.
virtual bool Predict(cv::Mat* im, FaceRecognitionResult* result);
// Initialization: sets up the backend and anything else needed for
// model inference.
virtual bool Initialize();
// Preprocessing of the input image.
// Mat is the data structure defined by FastDeploy.
// FDTensor is the preprocessed tensor handed to the backend for inference.
virtual bool Preprocess(Mat* mat, FDTensor* output);
// Postprocessing of the backend inference result, producing the user output.
// infer_result is the output tensor(s) of the backend inference.
// result is the model's prediction result.
virtual bool Postprocess(std::vector<FDTensor>& infer_result,
FaceRecognitionResult* result);
};
} // namespace deepinsight
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,84 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/deepinsight/partial_fc.h"
#include "fastdeploy/utils/perf.h"
#include "fastdeploy/vision/utils/utils.h"
namespace fastdeploy {
namespace vision {
namespace deepinsight {
// Constructs a PartialFC model. All arguments are forwarded to the
// InsightFaceRecognitionModel constructor; afterwards PartialFC::Initialize()
// runs and its outcome is stored in `initialized`.
PartialFC::PartialFC(const std::string& model_file,
                     const std::string& params_file,
                     const RuntimeOption& custom_option,
                     const Frontend& model_format)
    : InsightFaceRecognitionModel(model_file, params_file, custom_option,
                                  model_format) {
  const bool init_ok = Initialize();
  initialized = init_ok;
}
// Applies the PartialFC-specific preprocessing parameters. Modify this
// function if PartialFC ever needs a different initialization.
//
// The backend is normally initialized by the parent class constructor. In that
// case `initialized` is already true and the parent's Initialize() must not be
// called again, because it would initialize the backend a second time. Only
// when the parent has not initialized the backend is
// InsightFaceRecognitionModel::Initialize() invoked here.
bool PartialFC::Initialize() {
  const bool backend_ready = initialized;
  if (!backend_ready && !InsightFaceRecognitionModel::Initialize()) {
    FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
    return false;
  }
  // Re-init the parameters for this specific sub-class.
  size = {112, 112};
  alpha = {1.f / 127.5f, 1.f / 127.5f, 1.f / 127.5f};
  beta = {-1.f, -1.f, -1.f};  // RGB
  swap_rb = true;
  l2_normalize = false;
  return true;
}
// Preprocessing hook. PartialFC currently delegates to the parent class;
// modify this function if the preprocessing ever needs to differ.
bool PartialFC::Preprocess(Mat* mat, FDTensor* output) {
  const bool ok = InsightFaceRecognitionModel::Preprocess(mat, output);
  return ok;
}
// Postprocessing hook. PartialFC currently delegates to the parent class;
// modify this function if the postprocessing ever needs to differ.
bool PartialFC::Postprocess(std::vector<FDTensor>& infer_result,
                            FaceRecognitionResult* result) {
  const bool ok =
      InsightFaceRecognitionModel::Postprocess(infer_result, result);
  return ok;
}
// User-facing prediction entry point; delegates to the parent class.
// If the pre/post-processing changes, override Preprocess and Postprocess in
// this sub-class; this function should then call the sub-class's own versions.
bool PartialFC::Predict(cv::Mat* im, FaceRecognitionResult* result) {
  const bool ok = InsightFaceRecognitionModel::Predict(im, result);
  return ok;
}
} // namespace deepinsight
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,64 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/fastdeploy_model.h"
#include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/result.h"
#include "fastdeploy/vision/deepinsight/insightface_rec.h"
namespace fastdeploy {
namespace vision {
namespace deepinsight {
// PartialFC face recognition model built on InsightFaceRecognitionModel.
class FASTDEPLOY_DECL PartialFC : public InsightFaceRecognitionModel {
public:
// When model_format is ONNX, params_file does not need to be specified.
// When model_format is Paddle, both model_file and params_file are required.
PartialFC(const std::string& model_file, const std::string& params_file = "",
const RuntimeOption& custom_option = RuntimeOption(),
const Frontend& model_format = Frontend::ONNX);
// Name of the model.
std::string ModelName() const override {
return "deepinsight/insightface/recognition/partial_fc";
}
// Prediction interface, i.e. the entry point called by users.
// im is the user's input data; for CV tasks it is currently always a cv::Mat.
// result is the output structure of the model prediction.
bool Predict(cv::Mat* im, FaceRecognitionResult* result) override;
// The parent class holds the basic configurable attributes:
// size, alpha, beta, swap_rb, l2_normalize.
private:
// Initialization: sets up the backend and anything else needed for
// model inference.
bool Initialize() override;
// Preprocessing of the input image.
// Mat is the data structure defined by FastDeploy.
// FDTensor is the preprocessed tensor handed to the backend for inference.
bool Preprocess(Mat* mat, FDTensor* output) override;
// Postprocessing of the backend inference result, producing the user output.
// infer_result is the output tensor(s) of the backend inference.
// result is the model's prediction result.
bool Postprocess(std::vector<FDTensor>& infer_result,
FaceRecognitionResult* result) override;
};
} // namespace deepinsight
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,82 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/deepinsight/vpl.h"
#include "fastdeploy/utils/perf.h"
#include "fastdeploy/vision/utils/utils.h"
namespace fastdeploy {
namespace vision {
namespace deepinsight {
// Constructs a VPL model. All arguments are forwarded to the
// InsightFaceRecognitionModel constructor; afterwards VPL::Initialize() runs
// and its outcome is stored in `initialized`.
VPL::VPL(const std::string& model_file, const std::string& params_file,
         const RuntimeOption& custom_option, const Frontend& model_format)
    : InsightFaceRecognitionModel(model_file, params_file, custom_option,
                                  model_format) {
  const bool init_ok = Initialize();
  initialized = init_ok;
}
// Applies the VPL-specific preprocessing parameters. Modify this function if
// VPL ever needs a different initialization.
//
// The backend is normally initialized by the parent class constructor. In that
// case `initialized` is already true and the parent's Initialize() must not be
// called again, because it would initialize the backend a second time. Only
// when the parent has not initialized the backend is
// InsightFaceRecognitionModel::Initialize() invoked here.
bool VPL::Initialize() {
  const bool backend_ready = initialized;
  if (!backend_ready && !InsightFaceRecognitionModel::Initialize()) {
    FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
    return false;
  }
  // Re-init the parameters for this specific sub-class.
  size = {112, 112};
  alpha = {1.f / 127.5f, 1.f / 127.5f, 1.f / 127.5f};
  beta = {-1.f, -1.f, -1.f};  // RGB
  swap_rb = true;
  l2_normalize = false;
  return true;
}
// Preprocessing hook. VPL currently delegates to the parent class; modify this
// function if the preprocessing ever needs to differ.
bool VPL::Preprocess(Mat* mat, FDTensor* output) {
  const bool ok = InsightFaceRecognitionModel::Preprocess(mat, output);
  return ok;
}
// Postprocessing hook. VPL currently delegates to the parent class; modify
// this function if the postprocessing ever needs to differ.
bool VPL::Postprocess(std::vector<FDTensor>& infer_result,
                      FaceRecognitionResult* result) {
  const bool ok =
      InsightFaceRecognitionModel::Postprocess(infer_result, result);
  return ok;
}
// User-facing prediction entry point; delegates to the parent class.
// If the pre/post-processing changes, override Preprocess and Postprocess in
// this sub-class; this function should then call the sub-class's own versions.
bool VPL::Predict(cv::Mat* im, FaceRecognitionResult* result) {
  const bool ok = InsightFaceRecognitionModel::Predict(im, result);
  return ok;
}
} // namespace deepinsight
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,65 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/fastdeploy_model.h"
#include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/result.h"
#include "fastdeploy/vision/deepinsight/insightface_rec.h"
namespace fastdeploy {
namespace vision {
namespace deepinsight {
// VPL face recognition model built on InsightFaceRecognitionModel.
class FASTDEPLOY_DECL VPL : public InsightFaceRecognitionModel {
public:
// When model_format is ONNX, params_file does not need to be specified.
// When model_format is Paddle, both model_file and params_file are required.
// VPL supports the IResNet and IResNet1024 backbones.
VPL(const std::string& model_file, const std::string& params_file = "",
const RuntimeOption& custom_option = RuntimeOption(),
const Frontend& model_format = Frontend::ONNX);
// Name of the model.
std::string ModelName() const override {
return "deepinsight/insightface/recognition/vpl";
}
// Prediction interface, i.e. the entry point called by users.
// im is the user's input data; for CV tasks it is currently always a cv::Mat.
// result is the output structure of the model prediction.
bool Predict(cv::Mat* im, FaceRecognitionResult* result) override;
// The parent class holds the basic configurable attributes:
// size, alpha, beta, swap_rb, l2_normalize.
private:
// Initialization: sets up the backend and anything else needed for
// model inference.
bool Initialize() override;
// Preprocessing of the input image.
// Mat is the data structure defined by FastDeploy.
// FDTensor is the preprocessed tensor handed to the backend for inference.
bool Preprocess(Mat* mat, FDTensor* output) override;
// Postprocessing of the backend inference result, producing the user output.
// infer_result is the output tensor(s) of the backend inference.
// result is the model's prediction result.
bool Postprocess(std::vector<FDTensor>& infer_result,
FaceRecognitionResult* result) override;
};
} // namespace deepinsight
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,49 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/utils/utils.h"
namespace fastdeploy {
namespace vision {
namespace utils {
// Returns the cosine similarity between `a` and `b`, i.e.
// dot(a, b) / (|a| * |b|).
//
// Both vectors must have the same, non-zero length (enforced by FDASSERT).
// `normalized` states whether the inputs are already L2-normalized. It is kept
// for interface compatibility only: cosine similarity is invariant to positive
// scaling of either vector, so the same formula is valid in both cases and no
// normalized temporary copies have to be allocated.
// If either vector has zero norm the similarity is undefined; 0 is returned
// instead of the NaN that 0/0 would produce.
float CosineSimilarity(const std::vector<float>& a, const std::vector<float>& b,
                       bool normalized) {
  FDASSERT((a.size() == b.size()) && (a.size() != 0),
           "The size of a and b must be equal and >= 1.");
  (void)normalized;  // The result does not depend on the scale of the inputs.
  const size_t num_val = a.size();
  float mul_a = 0.f;
  float mul_b = 0.f;
  float mul_ab = 0.f;
  for (size_t i = 0; i < num_val; ++i) {
    mul_a += (a[i] * a[i]);
    mul_b += (b[i] * b[i]);
    mul_ab += (a[i] * b[i]);
  }
  const float denom = std::sqrt(mul_a) * std::sqrt(mul_b);
  // Compare with == so that NaN inputs still propagate as NaN.
  if (denom == 0.f) {
    return 0.f;
  }
  return mul_ab / denom;
}
} // namespace utils
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,41 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/utils/utils.h"
namespace fastdeploy {
namespace vision {
namespace utils {
// Returns a copy of `values` scaled to unit L2 (Euclidean) norm.
//
// An empty input yields an empty vector. When the L2 norm is zero (every
// element is zero) there is no direction to normalize, so the input is
// returned unchanged instead of dividing by zero and filling the result with
// NaN.
std::vector<float> L2Normalize(const std::vector<float>& values) {
  if (values.empty()) {
    return {};
  }
  float l2_sum_val = 0.f;
  for (const float v : values) {
    l2_sum_val += (v * v);
  }
  const float l2_sum_sqrt = std::sqrt(l2_sum_val);
  if (l2_sum_sqrt == 0.f) {
    return values;
  }
  std::vector<float> norm;
  norm.reserve(values.size());
  for (const float v : values) {
    norm.push_back(v / l2_sum_sqrt);
  }
  return norm;
}
} // namespace utils
} // namespace vision
} // namespace fastdeploy

View File

@@ -127,6 +127,14 @@ void SortDetectionResult(DetectionResult* output);
void SortDetectionResult(FaceDetectionResult* result);
// L2 Norm / cosine similarity (for face recognition, ...)
// L2Normalize returns `values` divided element-wise by its L2 norm; an empty
// input yields an empty vector.
FASTDEPLOY_DECL std::vector<float> L2Normalize(
const std::vector<float>& values);
// CosineSimilarity returns the cosine similarity of `a` and `b`, which must
// have the same non-zero size. `normalized` (default true) tells the function
// whether the inputs are already L2-normalized.
FASTDEPLOY_DECL float CosineSimilarity(const std::vector<float>& a,
const std::vector<float>& b,
bool normalized = true);
} // namespace utils
} // namespace vision
} // namespace fastdeploy

View File

@@ -28,7 +28,7 @@ void BindRangiLyu(pybind11::module& m);
void BindLinzaer(pybind11::module& m);
void BindBiubug6(pybind11::module& m);
void BindPpogg(pybind11::module& m);
void BindDeepinsight(pybind11::module& m);
void BindDeepInsight(pybind11::module& m);
#ifdef ENABLE_VISION_VISUALIZE
void BindVisualize(pybind11::module& m);
#endif
@@ -58,6 +58,7 @@ void BindVision(pybind11::module& m) {
&vision::FaceDetectionResult::landmarks_per_face)
.def("__repr__", &vision::FaceDetectionResult::Str)
.def("__str__", &vision::FaceDetectionResult::Str);
pybind11::class_<vision::SegmentationResult>(m, "SegmentationResult")
.def(pybind11::init())
.def_readwrite("label_map", &vision::SegmentationResult::label_map)
@@ -67,6 +68,12 @@ void BindVision(pybind11::module& m) {
.def("__repr__", &vision::SegmentationResult::Str)
.def("__str__", &vision::SegmentationResult::Str);
pybind11::class_<vision::FaceRecognitionResult>(m, "FaceRecognitionResult")
.def(pybind11::init())
.def_readwrite("embedding", &vision::FaceRecognitionResult::embedding)
.def("__repr__", &vision::FaceRecognitionResult::Str)
.def("__str__", &vision::FaceRecognitionResult::Str);
BindPPCls(m);
BindPPDet(m);
BindPPSeg(m);
@@ -79,7 +86,7 @@ void BindVision(pybind11::module& m) {
BindLinzaer(m);
BindBiubug6(m);
BindPpogg(m);
BindDeepinsight(m);
BindDeepInsight(m);
#ifdef ENABLE_VISION_VISUALIZE
BindVisualize(m);
#endif

View File

@@ -0,0 +1,64 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision.h"
#include "fastdeploy/vision/utils/utils.h"
// Demo: extracts ArcFace embeddings for three face images and prints the
// cosine similarity between face 0 / face 1 and between face 0 / face 2.
// Returns 0 on success and -1 if the model, an image, or a prediction fails.
int main() {
  namespace vis = fastdeploy::vision;
  // Faces 0 and 1 are the same person; faces 0 and 2 are different people.
  std::string model_file = "../resources/models/ms1mv3_arcface_r100.onnx";
  std::string face0_path = "../resources/images/face_recognition_0.png";
  std::string face1_path = "../resources/images/face_recognition_1.png";
  std::string face2_path = "../resources/images/face_recognition_2.png";

  auto model = vis::deepinsight::ArcFace(model_file);
  if (!model.Initialized()) {
    std::cerr << "Init Failed! Model: " << model_file << std::endl;
    return -1;
  } else {
    std::cout << "Init Done! Model:" << model_file << std::endl;
  }
  model.EnableDebug();
  // Ask the model to output L2-normalized embeddings.
  model.l2_normalize = true;

  cv::Mat face0 = cv::imread(face0_path);
  cv::Mat face1 = cv::imread(face1_path);
  cv::Mat face2 = cv::imread(face2_path);
  // cv::imread reports failure by returning an empty Mat, so check before
  // handing the images to the model.
  if (face0.empty() || face1.empty() || face2.empty()) {
    std::cerr << "Failed to read input images: " << face0_path << ", "
              << face1_path << ", " << face2_path << std::endl;
    return -1;
  }

  vis::FaceRecognitionResult res0;
  vis::FaceRecognitionResult res1;
  vis::FaceRecognitionResult res2;
  if ((!model.Predict(&face0, &res0)) || (!model.Predict(&face1, &res1)) ||
      (!model.Predict(&face2, &res2))) {
    std::cerr << "Prediction Failed." << std::endl;
    return -1;
  }
  std::cout << "Prediction Done!" << std::endl;

  // Print a summary of each embedding.
  std::cout << "--- [Face 0]:" << res0.Str();
  std::cout << "--- [Face 1]:" << res1.Str();
  std::cout << "--- [Face 2]:" << res2.Str();

  // Compute the cosine similarities.
  float cosine01 = vis::utils::CosineSimilarity(res0.embedding, res1.embedding,
                                                model.l2_normalize);
  float cosine02 = vis::utils::CosineSimilarity(res0.embedding, res2.embedding,
                                                model.l2_normalize);
  std::cout << "Detect Done! Cosine 01: " << cosine01
            << ", Cosine 02:" << cosine02 << std::endl;
  return 0;
}

View File

@@ -0,0 +1,64 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision.h"
#include "fastdeploy/vision/utils/utils.h"
// Demo: extracts CosFace embeddings for three face images and prints the
// cosine similarity between face 0 / face 1 and between face 0 / face 2.
// Returns 0 on success and -1 if the model, an image, or a prediction fails.
int main() {
  namespace vis = fastdeploy::vision;
  // Faces 0 and 1 are the same person; faces 0 and 2 are different people.
  std::string model_file = "../resources/models/glint360k_cosface_r100.onnx";
  std::string face0_path = "../resources/images/face_recognition_0.png";
  std::string face1_path = "../resources/images/face_recognition_1.png";
  std::string face2_path = "../resources/images/face_recognition_2.png";

  auto model = vis::deepinsight::CosFace(model_file);
  if (!model.Initialized()) {
    std::cerr << "Init Failed! Model: " << model_file << std::endl;
    return -1;
  } else {
    std::cout << "Init Done! Model:" << model_file << std::endl;
  }
  model.EnableDebug();
  // Ask the model to output L2-normalized embeddings.
  model.l2_normalize = true;

  cv::Mat face0 = cv::imread(face0_path);
  cv::Mat face1 = cv::imread(face1_path);
  cv::Mat face2 = cv::imread(face2_path);
  // cv::imread reports failure by returning an empty Mat, so check before
  // handing the images to the model.
  if (face0.empty() || face1.empty() || face2.empty()) {
    std::cerr << "Failed to read input images: " << face0_path << ", "
              << face1_path << ", " << face2_path << std::endl;
    return -1;
  }

  vis::FaceRecognitionResult res0;
  vis::FaceRecognitionResult res1;
  vis::FaceRecognitionResult res2;
  if ((!model.Predict(&face0, &res0)) || (!model.Predict(&face1, &res1)) ||
      (!model.Predict(&face2, &res2))) {
    std::cerr << "Prediction Failed." << std::endl;
    return -1;
  }
  std::cout << "Prediction Done!" << std::endl;

  // Print a summary of each embedding.
  std::cout << "--- [Face 0]:" << res0.Str();
  std::cout << "--- [Face 1]:" << res1.Str();
  std::cout << "--- [Face 2]:" << res2.Str();

  // Compute the cosine similarities.
  float cosine01 = vis::utils::CosineSimilarity(res0.embedding, res1.embedding,
                                                model.l2_normalize);
  float cosine02 = vis::utils::CosineSimilarity(res0.embedding, res2.embedding,
                                                model.l2_normalize);
  std::cout << "Detect Done! Cosine 01: " << cosine01
            << ", Cosine 02:" << cosine02 << std::endl;
  return 0;
}

View File

@@ -0,0 +1,64 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision.h"
#include "fastdeploy/vision/utils/utils.h"
// Demo: extracts embeddings for three face images with the generic
// InsightFaceRecognitionModel and prints the cosine similarity between
// face 0 / face 1 and between face 0 / face 2.
// Returns 0 on success and -1 if the model, an image, or a prediction fails.
int main() {
  namespace vis = fastdeploy::vision;
  // Faces 0 and 1 are the same person; faces 0 and 2 are different people.
  std::string model_file = "../resources/models/ms1mv3_arcface_r100.onnx";
  std::string face0_path = "../resources/images/face_recognition_0.png";
  std::string face1_path = "../resources/images/face_recognition_1.png";
  std::string face2_path = "../resources/images/face_recognition_2.png";

  auto model = vis::deepinsight::InsightFaceRecognitionModel(model_file);
  if (!model.Initialized()) {
    std::cerr << "Init Failed! Model: " << model_file << std::endl;
    return -1;
  } else {
    std::cout << "Init Done! Model:" << model_file << std::endl;
  }
  model.EnableDebug();
  // Ask the model to output L2-normalized embeddings.
  model.l2_normalize = true;

  cv::Mat face0 = cv::imread(face0_path);
  cv::Mat face1 = cv::imread(face1_path);
  cv::Mat face2 = cv::imread(face2_path);
  // cv::imread reports failure by returning an empty Mat, so check before
  // handing the images to the model.
  if (face0.empty() || face1.empty() || face2.empty()) {
    std::cerr << "Failed to read input images: " << face0_path << ", "
              << face1_path << ", " << face2_path << std::endl;
    return -1;
  }

  vis::FaceRecognitionResult res0;
  vis::FaceRecognitionResult res1;
  vis::FaceRecognitionResult res2;
  if ((!model.Predict(&face0, &res0)) || (!model.Predict(&face1, &res1)) ||
      (!model.Predict(&face2, &res2))) {
    std::cerr << "Prediction Failed." << std::endl;
    return -1;
  }
  std::cout << "Prediction Done!" << std::endl;

  // Print a summary of each embedding.
  std::cout << "--- [Face 0]:" << res0.Str();
  std::cout << "--- [Face 1]:" << res1.Str();
  std::cout << "--- [Face 2]:" << res2.Str();

  // Compute the cosine similarities.
  float cosine01 = vis::utils::CosineSimilarity(res0.embedding, res1.embedding,
                                                model.l2_normalize);
  float cosine02 = vis::utils::CosineSimilarity(res0.embedding, res2.embedding,
                                                model.l2_normalize);
  std::cout << "Detect Done! Cosine 01: " << cosine01
            << ", Cosine 02:" << cosine02 << std::endl;
  return 0;
}

View File

@@ -0,0 +1,64 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision.h"
#include "fastdeploy/vision/utils/utils.h"
// Demo: extracts PartialFC embeddings for three face images and prints the
// cosine similarity between face 0 / face 1 and between face 0 / face 2.
// Returns 0 on success and -1 if the model, an image, or a prediction fails.
int main() {
  namespace vis = fastdeploy::vision;
  // Faces 0 and 1 are the same person; faces 0 and 2 are different people.
  std::string model_file = "../resources/models/partial_fc_glint360k_r100.onnx";
  std::string face0_path = "../resources/images/face_recognition_0.png";
  std::string face1_path = "../resources/images/face_recognition_1.png";
  std::string face2_path = "../resources/images/face_recognition_2.png";

  auto model = vis::deepinsight::PartialFC(model_file);
  if (!model.Initialized()) {
    std::cerr << "Init Failed! Model: " << model_file << std::endl;
    return -1;
  } else {
    std::cout << "Init Done! Model:" << model_file << std::endl;
  }
  model.EnableDebug();
  // Ask the model to output L2-normalized embeddings.
  model.l2_normalize = true;

  cv::Mat face0 = cv::imread(face0_path);
  cv::Mat face1 = cv::imread(face1_path);
  cv::Mat face2 = cv::imread(face2_path);
  // cv::imread reports failure by returning an empty Mat, so check before
  // handing the images to the model.
  if (face0.empty() || face1.empty() || face2.empty()) {
    std::cerr << "Failed to read input images: " << face0_path << ", "
              << face1_path << ", " << face2_path << std::endl;
    return -1;
  }

  vis::FaceRecognitionResult res0;
  vis::FaceRecognitionResult res1;
  vis::FaceRecognitionResult res2;
  if ((!model.Predict(&face0, &res0)) || (!model.Predict(&face1, &res1)) ||
      (!model.Predict(&face2, &res2))) {
    std::cerr << "Prediction Failed." << std::endl;
    return -1;
  }
  std::cout << "Prediction Done!" << std::endl;

  // Print a summary of each embedding.
  std::cout << "--- [Face 0]:" << res0.Str();
  std::cout << "--- [Face 1]:" << res1.Str();
  std::cout << "--- [Face 2]:" << res2.Str();

  // Compute the cosine similarities.
  float cosine01 = vis::utils::CosineSimilarity(res0.embedding, res1.embedding,
                                                model.l2_normalize);
  float cosine02 = vis::utils::CosineSimilarity(res0.embedding, res2.embedding,
                                                model.l2_normalize);
  std::cout << "Detect Done! Cosine 01: " << cosine01
            << ", Cosine 02:" << cosine02 << std::endl;
  return 0;
}

View File

@@ -0,0 +1,64 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision.h"
#include "fastdeploy/vision/utils/utils.h"
// Demo: extracts VPL embeddings for three face images and prints the
// cosine similarity between face 0 / face 1 and between face 0 / face 2.
// Returns 0 on success and -1 if the model, an image, or a prediction fails.
int main() {
  namespace vis = fastdeploy::vision;
  // Faces 0 and 1 are the same person; faces 0 and 2 are different people.
  std::string model_file = "../resources/models/ms1mv3_r100_lr01.onnx";
  std::string face0_path = "../resources/images/face_recognition_0.png";
  std::string face1_path = "../resources/images/face_recognition_1.png";
  std::string face2_path = "../resources/images/face_recognition_2.png";

  auto model = vis::deepinsight::VPL(model_file);
  if (!model.Initialized()) {
    std::cerr << "Init Failed! Model: " << model_file << std::endl;
    return -1;
  } else {
    std::cout << "Init Done! Model:" << model_file << std::endl;
  }
  model.EnableDebug();
  // Ask the model to output L2-normalized embeddings.
  model.l2_normalize = true;

  cv::Mat face0 = cv::imread(face0_path);
  cv::Mat face1 = cv::imread(face1_path);
  cv::Mat face2 = cv::imread(face2_path);
  // cv::imread reports failure by returning an empty Mat, so check before
  // handing the images to the model.
  if (face0.empty() || face1.empty() || face2.empty()) {
    std::cerr << "Failed to read input images: " << face0_path << ", "
              << face1_path << ", " << face2_path << std::endl;
    return -1;
  }

  vis::FaceRecognitionResult res0;
  vis::FaceRecognitionResult res1;
  vis::FaceRecognitionResult res2;
  if ((!model.Predict(&face0, &res0)) || (!model.Predict(&face1, &res1)) ||
      (!model.Predict(&face2, &res2))) {
    std::cerr << "Prediction Failed." << std::endl;
    return -1;
  }
  std::cout << "Prediction Done!" << std::endl;

  // Print a summary of each embedding.
  std::cout << "--- [Face 0]:" << res0.Str();
  std::cout << "--- [Face 1]:" << res1.Str();
  std::cout << "--- [Face 2]:" << res2.Str();

  // Compute the cosine similarities.
  float cosine01 = vis::utils::CosineSimilarity(res0.embedding, res1.embedding,
                                                model.l2_normalize);
  float cosine02 = vis::utils::CosineSimilarity(res0.embedding, res2.embedding,
                                                model.l2_normalize);
  std::cout << "Detect Done! Cosine 01: " << cosine01
            << ", Cosine 02:" << cosine02 << std::endl;
  return 0;
}

View File

@@ -37,7 +37,7 @@ class RetinaFace(FastDeployModel):
return self._model.predict(input_image, conf_threshold,
nms_iou_threshold)
# 一些跟UltraFace模型有关的属性封装
# 一些跟模型有关的属性封装
# 多数是预处理相关可通过修改如model.size = [640, 480]改变预处理时resize的大小前提是模型支持
@property
def size(self):

View File

@@ -156,3 +156,408 @@ class SCRFD(FastDeployModel):
assert isinstance(
value, int), "The value to set `num_anchors` must be type of int."
self._model.num_anchors = value
class InsightFaceRecognitionModel(FastDeployModel):
    """Wrapper around ``C.vision.deepinsight.InsightFaceRecognitionModel``.

    Exposes ``predict`` plus the ``size``, ``alpha``, ``beta``, ``swap_rb`` and
    ``l2_normalize`` attributes of the wrapped model as validated properties.
    """

    def __init__(self,
                 model_file,
                 params_file="",
                 runtime_option=None,
                 model_format=Frontend.ONNX):
        # The base-class initializer sets up the backend option and stores it
        # in self._runtime_option.
        super(InsightFaceRecognitionModel, self).__init__(runtime_option)

        self._model = C.vision.deepinsight.InsightFaceRecognitionModel(
            model_file, params_file, self._runtime_option, model_format)
        # self.initialized tells whether the whole model initialized correctly.
        assert self.initialized, "InsightFaceRecognitionModel initialize failed."

    def predict(self, input_image):
        """Forward ``input_image`` to the wrapped model's ``predict``."""
        return self._model.predict(input_image)

    # Attribute wrappers for InsightFaceRecognitionModel.
    # Most relate to preprocessing, e.g. model.size = [112, 112] changes the
    # resize target used in preprocessing (provided the model supports it).
    @property
    def size(self):
        return self._model.size

    @property
    def alpha(self):
        return self._model.alpha

    @property
    def beta(self):
        return self._model.beta

    @property
    def swap_rb(self):
        return self._model.swap_rb

    @property
    def l2_normalize(self):
        return self._model.l2_normalize

    @size.setter
    def size(self, wh):
        # isinstance() requires a type or a tuple of types; a list of types
        # raises TypeError, so the accepted types are given as a tuple.
        assert isinstance(wh, (list, tuple)),\
            "The value to set `size` must be type of tuple or list."
        assert len(wh) == 2,\
            "The value to set `size` must contatins 2 elements means [width, height], but now it contains {} elements.".format(
            len(wh))
        self._model.size = wh

    @alpha.setter
    def alpha(self, value):
        assert isinstance(value, (list, tuple)),\
            "The value to set `alpha` must be type of tuple or list."
        assert len(value) == 3,\
            "The value to set `alpha` must contatins 3 elements for each channels, but now it contains {} elements.".format(
            len(value))
        self._model.alpha = value

    @beta.setter
    def beta(self, value):
        assert isinstance(value, (list, tuple)),\
            "The value to set `beta` must be type of tuple or list."
        assert len(value) == 3,\
            "The value to set `beta` must contatins 3 elements for each channels, but now it contains {} elements.".format(
            len(value))
        self._model.beta = value

    @swap_rb.setter
    def swap_rb(self, value):
        assert isinstance(
            value, bool), "The value to set `swap_rb` must be type of bool."
        self._model.swap_rb = value

    @l2_normalize.setter
    def l2_normalize(self, value):
        assert isinstance(
            value,
            bool), "The value to set `l2_normalize` must be type of bool."
        self._model.l2_normalize = value
class ArcFace(FastDeployModel):
    """Wrapper around ``C.vision.deepinsight.ArcFace``.

    Exposes ``predict`` plus the ``size``, ``alpha``, ``beta``, ``swap_rb`` and
    ``l2_normalize`` attributes of the wrapped model as validated properties.
    """

    def __init__(self,
                 model_file,
                 params_file="",
                 runtime_option=None,
                 model_format=Frontend.ONNX):
        # The base-class initializer sets up the backend option and stores it
        # in self._runtime_option.
        super(ArcFace, self).__init__(runtime_option)

        self._model = C.vision.deepinsight.ArcFace(
            model_file, params_file, self._runtime_option, model_format)
        # self.initialized tells whether the whole model initialized correctly.
        assert self.initialized, "ArcFace initialize failed."

    def predict(self, input_image):
        """Forward ``input_image`` to the wrapped model's ``predict``."""
        return self._model.predict(input_image)

    # Attribute wrappers for the model.
    # Most relate to preprocessing, e.g. model.size = [112, 112] changes the
    # resize target used in preprocessing (provided the model supports it).
    @property
    def size(self):
        return self._model.size

    @property
    def alpha(self):
        return self._model.alpha

    @property
    def beta(self):
        return self._model.beta

    @property
    def swap_rb(self):
        return self._model.swap_rb

    @property
    def l2_normalize(self):
        return self._model.l2_normalize

    @size.setter
    def size(self, wh):
        # isinstance() requires a type or a tuple of types; a list of types
        # raises TypeError, so the accepted types are given as a tuple.
        assert isinstance(wh, (list, tuple)),\
            "The value to set `size` must be type of tuple or list."
        assert len(wh) == 2,\
            "The value to set `size` must contatins 2 elements means [width, height], but now it contains {} elements.".format(
            len(wh))
        self._model.size = wh

    @alpha.setter
    def alpha(self, value):
        assert isinstance(value, (list, tuple)),\
            "The value to set `alpha` must be type of tuple or list."
        assert len(value) == 3,\
            "The value to set `alpha` must contatins 3 elements for each channels, but now it contains {} elements.".format(
            len(value))
        self._model.alpha = value

    @beta.setter
    def beta(self, value):
        assert isinstance(value, (list, tuple)),\
            "The value to set `beta` must be type of tuple or list."
        assert len(value) == 3,\
            "The value to set `beta` must contatins 3 elements for each channels, but now it contains {} elements.".format(
            len(value))
        self._model.beta = value

    @swap_rb.setter
    def swap_rb(self, value):
        assert isinstance(
            value, bool), "The value to set `swap_rb` must be type of bool."
        self._model.swap_rb = value

    @l2_normalize.setter
    def l2_normalize(self, value):
        assert isinstance(
            value,
            bool), "The value to set `l2_normalize` must be type of bool."
        self._model.l2_normalize = value
class CosFace(FastDeployModel):
    """Wrapper around ``C.vision.deepinsight.CosFace``.

    Exposes ``predict`` plus the ``size``, ``alpha``, ``beta``, ``swap_rb`` and
    ``l2_normalize`` attributes of the wrapped model as validated properties.
    """

    def __init__(self,
                 model_file,
                 params_file="",
                 runtime_option=None,
                 model_format=Frontend.ONNX):
        # The base-class initializer sets up the backend option and stores it
        # in self._runtime_option.
        super(CosFace, self).__init__(runtime_option)

        self._model = C.vision.deepinsight.CosFace(
            model_file, params_file, self._runtime_option, model_format)
        # self.initialized tells whether the whole model initialized correctly.
        assert self.initialized, "CosFace initialize failed."

    def predict(self, input_image):
        """Forward ``input_image`` to the wrapped model's ``predict``."""
        return self._model.predict(input_image)

    # Attribute wrappers for the model.
    # Most relate to preprocessing, e.g. model.size = [112, 112] changes the
    # resize target used in preprocessing (provided the model supports it).
    @property
    def size(self):
        return self._model.size

    @property
    def alpha(self):
        return self._model.alpha

    @property
    def beta(self):
        return self._model.beta

    @property
    def swap_rb(self):
        return self._model.swap_rb

    @property
    def l2_normalize(self):
        return self._model.l2_normalize

    @size.setter
    def size(self, wh):
        # isinstance() requires a type or a tuple of types; a list of types
        # raises TypeError, so the accepted types are given as a tuple.
        assert isinstance(wh, (list, tuple)),\
            "The value to set `size` must be type of tuple or list."
        assert len(wh) == 2,\
            "The value to set `size` must contatins 2 elements means [width, height], but now it contains {} elements.".format(
            len(wh))
        self._model.size = wh

    @alpha.setter
    def alpha(self, value):
        assert isinstance(value, (list, tuple)),\
            "The value to set `alpha` must be type of tuple or list."
        assert len(value) == 3,\
            "The value to set `alpha` must contatins 3 elements for each channels, but now it contains {} elements.".format(
            len(value))
        self._model.alpha = value

    @beta.setter
    def beta(self, value):
        assert isinstance(value, (list, tuple)),\
            "The value to set `beta` must be type of tuple or list."
        assert len(value) == 3,\
            "The value to set `beta` must contatins 3 elements for each channels, but now it contains {} elements.".format(
            len(value))
        self._model.beta = value

    @swap_rb.setter
    def swap_rb(self, value):
        assert isinstance(
            value, bool), "The value to set `swap_rb` must be type of bool."
        self._model.swap_rb = value

    @l2_normalize.setter
    def l2_normalize(self, value):
        assert isinstance(
            value,
            bool), "The value to set `l2_normalize` must be type of bool."
        self._model.l2_normalize = value
class PartialFC(FastDeployModel):
    """Wrapper around ``C.vision.deepinsight.PartialFC``.

    Exposes ``predict`` plus the ``size``, ``alpha``, ``beta``, ``swap_rb`` and
    ``l2_normalize`` attributes of the wrapped model as validated properties.
    """

    def __init__(self,
                 model_file,
                 params_file="",
                 runtime_option=None,
                 model_format=Frontend.ONNX):
        # The base-class initializer sets up the backend option and stores it
        # in self._runtime_option.
        super(PartialFC, self).__init__(runtime_option)

        self._model = C.vision.deepinsight.PartialFC(
            model_file, params_file, self._runtime_option, model_format)
        # self.initialized tells whether the whole model initialized correctly.
        assert self.initialized, "PartialFC initialize failed."

    def predict(self, input_image):
        """Forward ``input_image`` to the wrapped model's ``predict``."""
        return self._model.predict(input_image)

    # Attribute wrappers for the model.
    # Most relate to preprocessing, e.g. model.size = [112, 112] changes the
    # resize target used in preprocessing (provided the model supports it).
    @property
    def size(self):
        return self._model.size

    @property
    def alpha(self):
        return self._model.alpha

    @property
    def beta(self):
        return self._model.beta

    @property
    def swap_rb(self):
        return self._model.swap_rb

    @property
    def l2_normalize(self):
        return self._model.l2_normalize

    @size.setter
    def size(self, wh):
        # isinstance() requires a type or a tuple of types; a list of types
        # raises TypeError, so the accepted types are given as a tuple.
        assert isinstance(wh, (list, tuple)),\
            "The value to set `size` must be type of tuple or list."
        assert len(wh) == 2,\
            "The value to set `size` must contatins 2 elements means [width, height], but now it contains {} elements.".format(
            len(wh))
        self._model.size = wh

    @alpha.setter
    def alpha(self, value):
        assert isinstance(value, (list, tuple)),\
            "The value to set `alpha` must be type of tuple or list."
        assert len(value) == 3,\
            "The value to set `alpha` must contatins 3 elements for each channels, but now it contains {} elements.".format(
            len(value))
        self._model.alpha = value

    @beta.setter
    def beta(self, value):
        assert isinstance(value, (list, tuple)),\
            "The value to set `beta` must be type of tuple or list."
        assert len(value) == 3,\
            "The value to set `beta` must contatins 3 elements for each channels, but now it contains {} elements.".format(
            len(value))
        self._model.beta = value

    @swap_rb.setter
    def swap_rb(self, value):
        assert isinstance(
            value, bool), "The value to set `swap_rb` must be type of bool."
        self._model.swap_rb = value

    @l2_normalize.setter
    def l2_normalize(self, value):
        assert isinstance(
            value,
            bool), "The value to set `l2_normalize` must be type of bool."
        self._model.l2_normalize = value
class VPL(FastDeployModel):
    """Wrapper around ``C.vision.deepinsight.VPL``.

    Exposes ``predict`` plus the ``size``, ``alpha``, ``beta``, ``swap_rb`` and
    ``l2_normalize`` attributes of the wrapped model as validated properties.
    """

    def __init__(self,
                 model_file,
                 params_file="",
                 runtime_option=None,
                 model_format=Frontend.ONNX):
        # The base-class initializer sets up the backend option and stores it
        # in self._runtime_option.
        super(VPL, self).__init__(runtime_option)

        self._model = C.vision.deepinsight.VPL(
            model_file, params_file, self._runtime_option, model_format)
        # self.initialized tells whether the whole model initialized correctly.
        assert self.initialized, "VPL initialize failed."

    def predict(self, input_image):
        """Forward ``input_image`` to the wrapped model's ``predict``."""
        return self._model.predict(input_image)

    # Attribute wrappers for the model.
    # Most relate to preprocessing, e.g. model.size = [112, 112] changes the
    # resize target used in preprocessing (provided the model supports it).
    @property
    def size(self):
        return self._model.size

    @property
    def alpha(self):
        return self._model.alpha

    @property
    def beta(self):
        return self._model.beta

    @property
    def swap_rb(self):
        return self._model.swap_rb

    @property
    def l2_normalize(self):
        return self._model.l2_normalize

    @size.setter
    def size(self, wh):
        # isinstance() requires a type or a tuple of types; a list of types
        # raises TypeError, so the accepted types are given as a tuple.
        assert isinstance(wh, (list, tuple)),\
            "The value to set `size` must be type of tuple or list."
        assert len(wh) == 2,\
            "The value to set `size` must contatins 2 elements means [width, height], but now it contains {} elements.".format(
            len(wh))
        self._model.size = wh

    @alpha.setter
    def alpha(self, value):
        assert isinstance(value, (list, tuple)),\
            "The value to set `alpha` must be type of tuple or list."
        assert len(value) == 3,\
            "The value to set `alpha` must contatins 3 elements for each channels, but now it contains {} elements.".format(
            len(value))
        self._model.alpha = value

    @beta.setter
    def beta(self, value):
        assert isinstance(value, (list, tuple)),\
            "The value to set `beta` must be type of tuple or list."
        assert len(value) == 3,\
            "The value to set `beta` must contatins 3 elements for each channels, but now it contains {} elements.".format(
            len(value))
        self._model.beta = value

    @swap_rb.setter
    def swap_rb(self, value):
        assert isinstance(
            value, bool), "The value to set `swap_rb` must be type of bool."
        self._model.swap_rb = value

    @l2_normalize.setter
    def l2_normalize(self, value):
        assert isinstance(
            value,
            bool), "The value to set `l2_normalize` must be type of bool."
        self._model.l2_normalize = value

View File

@@ -0,0 +1,80 @@
# ArcFace部署示例
## 0. 简介
当前支持模型版本为:[ArcFace CommitID:babb9a5](https://github.com/deepinsight/insightface/commit/babb9a5)
本文档说明如何进行[ArcFace](https://github.com/deepinsight/insightface/tree/master/recognition/arcface_torch) 的快速部署推理。本目录结构如下
```
.
├── cpp # C++ 代码目录
│   ├── CMakeLists.txt # C++ 代码编译CMakeLists文件
│   ├── README.md # C++ 代码编译部署文档
│   └── arcface.cc # C++ 示例代码
├── api.md # API 说明文档
├── README.md # ArcFace 部署文档
└── arcface.py # Python示例代码
```
## 1. 特别说明
fastdeploy支持 [insightface](https://github.com/deepinsight/insightface/tree/master/recognition) 的人脸识别模块recognition中大部分模型的部署，包括ArcFace、CosFace、Partial FC、VPL等。由于用法类似，这里仅用ArcFace来演示部署流程。所有支持的模型结构可参考 [ArcFace API文档](./api.md)。
## 2. 获取ONNX文件
访问[ArcFace](https://github.com/deepinsight/insightface/tree/master/recognition/arcface_torch)官方github库按照指引下载安装下载pt模型文件利用 `torch2onnx.py` 得到`onnx`格式文件。
* 下载ArcFace模型文件
```
Link: https://pan.baidu.com/share/init?surl=CL-l4zWqsI1oDuEEYVhj-g code: e8pw
```
* 导出onnx格式文件
```bash
PYTHONPATH=. python ./torch2onnx.py partial_fc/pytorch/ms1mv3_arcface_r100_fp16/backbone.pth --output ms1mv3_arcface_r100.onnx --network r100 --simplify 1
```
* 移动onnx文件到model_zoo/vision/arcface的目录
```bash
cp PATH/TO/ms1mv3_arcface_r100.onnx PATH/TO/model_zoo/vision/arcface/
```
## 3. 准备测试图片
准备3张仅包含人脸的测试图片，命名为face_recognition_*.png，并拷贝到可执行文件所在的目录，比如
```bash
face_recognition_0.png # 0,1 同一个人
face_recognition_1.png
face_recognition_2.png # 0,2 不同的人
```
## 4. 安装FastDeploy
使用如下命令安装FastDeploy注意到此处安装的是`vision-cpu`,也可根据需求安装`vision-gpu`
```bash
# 安装fastdeploy-python工具
pip install fastdeploy-python
# 安装vision-cpu模块
fastdeploy install vision-cpu
```
## 5. Python部署
执行如下代码即会自动下载ArcFace模型和测试图片
```bash
python arcface.py
```
执行完成后会输出检测结果如下
```
FaceRecognitionResult: [Dim(512), Min(-0.141219), Max(0.121645), Mean(-0.003172)]
FaceRecognitionResult: [Dim(512), Min(-0.117939), Max(0.141897), Mean(0.000407)]
FaceRecognitionResult: [Dim(512), Min(-0.124471), Max(0.112567), Mean(-0.001320)]
Cosine 01: 0.7211584683376316
Cosine 02: -0.06262668682788906
```
## 6. 其它文档
- [C++部署](./cpp/README.md)
- [ArcFace API文档](./api.md)

View File

@@ -0,0 +1,113 @@
# ArcFace API说明
## 0. 特别说明
fastdeploy支持 [insightface](https://github.com/deepinsight/insightface/tree/master/recognition) 的人脸识别模块recognition中大部分模型的部署包括ArcFace、CosFace、Partial FC、VPL等由于用法类似这里仅用ArcFace来说明参数设置。
## 1. Python API
### 1.1 ArcFace 类
#### 1.1.1 类初始化说明
```python
fastdeploy.vision.deepinsight.ArcFace(model_file, params_file="", runtime_option=None, model_format=fd.Frontend.ONNX)
```
ArcFace模型加载和初始化。当model_format为`fd.Frontend.ONNX`时，只需提供model_file，如`xxx.onnx`；当model_format为`fd.Frontend.PADDLE`时，则需同时提供model_file和params_file。
**参数**
> * **model_file**(str): 模型文件路径
> * **params_file**(str): 参数文件路径
> * **runtime_option**(RuntimeOption): 后端推理配置默认为None即采用默认配置
> * **model_format**(Frontend): 模型格式
#### 1.1.2 predict函数
> ```python
> ArcFace.predict(image_data)
> ```
> 模型预测接口，输入图像直接输出检测结果。
>
> **参数**
>
> > * **image_data**(np.ndarray): 输入数据注意需为HWCBGR格式
示例代码参考[arcface.py](./arcface.py)
### 1.2 其他支持的类
```python
fastdeploy.vision.deepinsight.ArcFace(model_file, params_file="", runtime_option=None, model_format=fd.Frontend.ONNX)
fastdeploy.vision.deepinsight.CosFace(model_file, params_file="", runtime_option=None, model_format=fd.Frontend.ONNX)
fastdeploy.vision.deepinsight.PartialFC(model_file, params_file="", runtime_option=None, model_format=fd.Frontend.ONNX)
fastdeploy.vision.deepinsight.VPL(model_file, params_file="", runtime_option=None, model_format=fd.Frontend.ONNX)
fastdeploy.vision.deepinsight.InsightFaceRecognitionModel(model_file, params_file="", runtime_option=None, model_format=fd.Frontend.ONNX)
```
Tips: 如果 [insightface](https://github.com/deepinsight/insightface/tree/master/recognition) 人脸识别的推理逻辑没有随它自身的版本发生太大变化,则可以都统一使用 InsightFaceRecognitionModel 进行推理。
## 2. C++ API
### 2.1 ArcFace 类
#### 2.1.1 类初始化说明
```C++
fastdeploy::vision::deepinsight::ArcFace(
const string& model_file,
const string& params_file = "",
const RuntimeOption& runtime_option = RuntimeOption(),
const Frontend& model_format = Frontend::ONNX)
```
ArcFace模型加载和初始化。当model_format为`Frontend::ONNX`时，只需提供model_file，如`xxx.onnx`；当model_format为`Frontend::PADDLE`时，则需同时提供model_file和params_file。
**参数**
> * **model_file**(str): 模型文件路径
> * **params_file**(str): 参数文件路径
> * **runtime_option**(RuntimeOption): 后端推理配置，默认为`RuntimeOption()`，即采用默认配置
> * **model_format**(Frontend): 模型格式
#### 2.1.2 Predict函数
> ```C++
> ArcFace::Predict(cv::Mat* im, FaceRecognitionResult* result)
> ```
> 模型预测接口，输入图像直接输出人脸识别结果。
>
> **参数**
>
> > * **im**: 输入图像，注意需为HWC，BGR格式
> > * **result**: 识别结果，result的成员embedding包含人脸特征向量
示例代码参考[cpp/arcface.cc](cpp/arcface.cc)
### 2.2 其他支持的类
```C++
fastdeploy::vision::deepinsight::ArcFace(
const string& model_file,
const string& params_file = "",
const RuntimeOption& runtime_option = RuntimeOption(),
const Frontend& model_format = Frontend::ONNX);
fastdeploy::vision::deepinsight::CosFace(
const string& model_file,
const string& params_file = "",
const RuntimeOption& runtime_option = RuntimeOption(),
const Frontend& model_format = Frontend::ONNX);
fastdeploy::vision::deepinsight::PartialFC(
const string& model_file,
const string& params_file = "",
const RuntimeOption& runtime_option = RuntimeOption(),
const Frontend& model_format = Frontend::ONNX);
fastdeploy::vision::deepinsight::VPL(
const string& model_file,
const string& params_file = "",
const RuntimeOption& runtime_option = RuntimeOption(),
const Frontend& model_format = Frontend::ONNX);
fastdeploy::vision::deepinsight::InsightFaceRecognitionModel(
const string& model_file,
const string& params_file = "",
const RuntimeOption& runtime_option = RuntimeOption(),
const Frontend& model_format = Frontend::ONNX);
```
Tips: 如果 [insightface](https://github.com/deepinsight/insightface/tree/master/recognition) 人脸识别的推理逻辑没有随它自身的版本发生太大变化,则可以都统一使用 InsightFaceRecognitionModel 进行推理。
## 3. 其它API使用
- [模型部署RuntimeOption配置](../../../docs/api/runtime_option.md)

View File

@@ -0,0 +1,46 @@
import fastdeploy as fd
import numpy as np
import cv2
# Cosine similarity
def cosine_similarity(a, b):
    """Return the cosine similarity of two 1-D vectors.

    Args:
        a: First vector; any sequence accepted by ``np.array``.
        b: Second vector, with the same length as ``a``.

    Returns:
        ``dot(a, b) / (||a||_2 * ||b||_2)``.
    """
    a = np.array(a)
    b = np.array(b)
    # np.linalg.norm(..., ord=2) is already the Euclidean length of the
    # vector, so the denominator is the plain product of the two norms.
    # Taking a further square root is only correct for unit-length inputs.
    norm_a = np.linalg.norm(a, ord=2)
    norm_b = np.linalg.norm(b, ord=2)
    return np.dot(a, b) / (norm_a * norm_b)
# Load the ArcFace model from its ONNX file
model = fd.vision.deepinsight.ArcFace("ms1mv3_arcface_r100.onnx")
print("Initialed model!")

# Read the test images: faces 0 and 1 are the same person,
# faces 0 and 2 are different people
face0, face1, face2 = [
    cv2.imread(path)
    for path in ("face_recognition_0.png", "face_recognition_1.png",
                 "face_recognition_2.png")
]

# Turn on the model's l2_normalize option before predicting
model.l2_normalize = True
result0, result1, result2 = [
    model.predict(face) for face in (face0, face1, face2)
]

# Compare face 0 against the other two faces with cosine similarity
embedding0, embedding1, embedding2 = (result0.embedding, result1.embedding,
                                      result2.embedding)
cosine01 = cosine_similarity(embedding0, embedding1)
cosine02 = cosine_similarity(embedding0, embedding2)

# Print each result, then the two similarity scores
for result in (result0, result1, result2):
    print(result, end="")
print("Cosine 01: ", cosine01)
print("Cosine 02: ", cosine02)
print(model.runtime_option)

View File

@@ -0,0 +1,17 @@
# cmake_minimum_required() must be called before project() so that the
# version and policy settings are in effect when the project is configured.
CMAKE_MINIMUM_REQUIRED(VERSION 3.16)
PROJECT(arcface_demo C CXX)

# On toolchains that use the old (pre-C++11) libstdc++ ABI, enable the
# following line for a compatible build
# add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)

# Path of the downloaded and extracted FastDeploy library
set(FASTDEPLOY_INSTALL_DIR ${PROJECT_SOURCE_DIR}/fastdeploy-linux-x64-0.3.0/)

include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake)

# Add the FastDeploy header directories
include_directories(${FASTDEPLOY_INCS})

add_executable(arcface_demo ${PROJECT_SOURCE_DIR}/arcface.cc)
# Link against the FastDeploy libraries
target_link_libraries(arcface_demo ${FASTDEPLOY_LIBS})

View File

@@ -0,0 +1,61 @@
# 编译ArcFace示例
## 0. 简介
当前支持模型版本为:[ArcFace CommitID:babb9a5](https://github.com/deepinsight/insightface/commit/babb9a5)
## 1. 下载和解压预测库
```bash
wget https://bj.bcebos.com/paddle2onnx/fastdeploy/fastdeploy-linux-x64-0.3.0.tgz
tar xvf fastdeploy-linux-x64-0.3.0.tgz
```
## 2. 编译示例代码
```bash
mkdir build && cd build
cmake ..
make -j
```
## 3. 特别说明
fastdeploy支持 [insightface](https://github.com/deepinsight/insightface/tree/master/recognition) 的人脸识别模块recognition中大部分模型的部署，包括ArcFace、CosFace、Partial FC、VPL等。由于用法类似，这里仅用ArcFace来演示部署流程。所有支持的模型结构可参考 [ArcFace API文档](../api.md)。
## 4. 获取ONNX文件
访问[ArcFace](https://github.com/deepinsight/insightface/tree/master/recognition/arcface_torch)官方github库，按照指引下载安装，下载pt模型文件，利用 `torch2onnx.py` 得到`onnx`格式文件。
* 下载ArcFace模型文件
```
Link: https://pan.baidu.com/share/init?surl=CL-l4zWqsI1oDuEEYVhj-g code: e8pw
```
* 导出onnx格式文件
```bash
PYTHONPATH=. python ./torch2onnx.py partial_fc/pytorch/ms1mv3_arcface_r100_fp16/backbone.pth --output ms1mv3_arcface_r100.onnx --network r100 --simplify 1
```
* 移动onnx文件到model_zoo/vision/arcface的目录
```bash
cp PATH/TO/ms1mv3_arcface_r100.onnx PATH/TO/model_zoo/vision/arcface/
```
## 5. 准备测试图片
准备3张仅包含人脸的测试图片，命名为face_recognition_*.png，并拷贝到可执行文件所在的目录，比如
```bash
face_recognition_0.png # 0,1 同一个人
face_recognition_1.png
face_recognition_2.png # 0,2 不同的人
```
## 6. 执行
```bash
./arcface_demo
```
执行完成后会输出检测结果如下
```
FaceRecognitionResult: [Dim(512), Min(-0.141219), Max(0.121645), Mean(-0.003172)]
FaceRecognitionResult: [Dim(512), Min(-0.117939), Max(0.141897), Mean(0.000407)]
FaceRecognitionResult: [Dim(512), Min(-0.124471), Max(0.112567), Mean(-0.001320)]
Cosine 01: 0.7211584683376316
Cosine 02: -0.06262668682788906
```

View File

@@ -0,0 +1,64 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision.h"
#include "fastdeploy/vision/utils/utils.h"
// Demo entry point: runs ArcFace on three face images, prints a summary of
// each embedding, and prints the cosine similarity of face 0 against faces 1
// and 2. Returns 0 on success and -1 if the model fails to initialize, an
// input image cannot be read, or a prediction fails.
int main() {
  namespace vis = fastdeploy::vision;
  // Faces 0 and 1 show the same person; faces 0 and 2 show different people.
  std::string model_file = "./ms1mv3_arcface_r100.onnx";
  std::string face0_path = "./face_recognition_0.png";
  std::string face1_path = "./face_recognition_1.png";
  std::string face2_path = "./face_recognition_2.png";

  auto model = vis::deepinsight::ArcFace(model_file);
  if (!model.Initialized()) {
    std::cerr << "Init Failed! Model: " << model_file << std::endl;
    return -1;
  } else {
    std::cout << "Init Done! Model:" << model_file << std::endl;
  }
  model.EnableDebug();
  // Ask the model to output the embedding after l2 normalization.
  model.l2_normalize = true;

  cv::Mat face0 = cv::imread(face0_path);
  cv::Mat face1 = cv::imread(face1_path);
  cv::Mat face2 = cv::imread(face2_path);
  // cv::imread signals failure by returning an empty matrix; stop here
  // rather than passing an empty image to Predict.
  if (face0.empty() || face1.empty() || face2.empty()) {
    std::cerr << "Failed to read input images: " << face0_path << ", "
              << face1_path << ", " << face2_path << std::endl;
    return -1;
  }

  vis::FaceRecognitionResult res0;
  vis::FaceRecognitionResult res1;
  vis::FaceRecognitionResult res2;
  if ((!model.Predict(&face0, &res0)) || (!model.Predict(&face1, &res1)) ||
      (!model.Predict(&face2, &res2))) {
    std::cerr << "Prediction Failed." << std::endl;
    return -1;
  }
  std::cout << "Prediction Done!" << std::endl;

  // Print the summary string of each recognition result.
  std::cout << "--- [Face 0]:" << res0.Str();
  std::cout << "--- [Face 1]:" << res1.Str();
  std::cout << "--- [Face 2]:" << res2.Str();

  // Compute the cosine similarity of face 0 against faces 1 and 2; the
  // model's l2_normalize flag is forwarded as the third argument.
  float cosine01 = vis::utils::CosineSimilarity(res0.embedding, res1.embedding,
                                                model.l2_normalize);
  float cosine02 = vis::utils::CosineSimilarity(res0.embedding, res2.embedding,
                                                model.l2_normalize);
  std::cout << "Detect Done! Cosine 01: " << cosine01
            << ", Cosine 02:" << cosine02 << std::endl;
  return 0;
}