From b557dbc2d8ea0df6930da3c9f2c61c20ff33960c Mon Sep 17 00:00:00 2001 From: WJJ1995 Date: Wed, 12 Oct 2022 15:57:26 +0800 Subject: [PATCH] Add YOLOv5-cls Model (#335) * add yolov5cls * fixed bugs * fixed bugs * fixed preprocess bug * add yolov5cls readme * deal with comments * Add YOLOv5Cls Note * add yolov5cls test Co-authored-by: Jason --- .../paddleclas/python/README.md | 2 +- .../vision/classification/yolov5cls/README.md | 29 +++++ .../yolov5cls/cpp/CMakeLists.txt | 14 +++ .../classification/yolov5cls/cpp/README.md | 89 ++++++++++++++ .../classification/yolov5cls/cpp/infer.cc | 104 ++++++++++++++++ .../classification/yolov5cls/python/README.md | 73 +++++++++++ .../classification/yolov5cls/python/infer.py | 51 ++++++++ examples/vision/detection/yolov5/README.md | 2 +- fastdeploy/vision.h | 1 + .../classification/classification_pybind.cc | 2 + .../classification/contrib/yolov5cls.cc | 116 ++++++++++++++++++ .../vision/classification/contrib/yolov5cls.h | 70 +++++++++++ .../contrib/yolov5cls_pybind.cc | 32 +++++ .../vision/classification/__init__.py | 1 + .../vision/classification/contrib/__init__.py | 15 +++ .../classification/contrib/yolov5cls.py | 69 +++++++++++ tests/eval_example/test_ppmatting.py | 2 +- tests/eval_example/test_quantize_diff.py | 2 +- tests/eval_example/test_yolov5cls.py | 49 ++++++++ 19 files changed, 719 insertions(+), 4 deletions(-) create mode 100644 examples/vision/classification/yolov5cls/README.md create mode 100644 examples/vision/classification/yolov5cls/cpp/CMakeLists.txt create mode 100644 examples/vision/classification/yolov5cls/cpp/README.md create mode 100644 examples/vision/classification/yolov5cls/cpp/infer.cc create mode 100644 examples/vision/classification/yolov5cls/python/README.md create mode 100644 examples/vision/classification/yolov5cls/python/infer.py mode change 100644 => 100755 fastdeploy/vision.h create mode 100755 fastdeploy/vision/classification/contrib/yolov5cls.cc create mode 100755 fastdeploy/vision/classification/contrib/yolov5cls.h create mode 100755 fastdeploy/vision/classification/contrib/yolov5cls_pybind.cc create mode 100644 python/fastdeploy/vision/classification/contrib/__init__.py create mode 100644 python/fastdeploy/vision/classification/contrib/yolov5cls.py create mode 100755 tests/eval_example/test_yolov5cls.py diff --git a/examples/vision/classification/paddleclas/python/README.md b/examples/vision/classification/paddleclas/python/README.md index 0d51afc1b..3bb8cb355 100644 --- a/examples/vision/classification/paddleclas/python/README.md +++ b/examples/vision/classification/paddleclas/python/README.md @@ -55,7 +55,7 @@ PaddleClas模型加载和初始化,其中model_file, params_file为训练模 > PaddleClasModel.predict(input_image, topk=1) > ``` > -> 模型预测结口,输入图像直接输出检测结果。 +> 模型预测结口,输入图像直接输出分类topk结果。 > > **参数** > diff --git a/examples/vision/classification/yolov5cls/README.md b/examples/vision/classification/yolov5cls/README.md new file mode 100644 index 000000000..9ed02b728 --- /dev/null +++ b/examples/vision/classification/yolov5cls/README.md @@ -0,0 +1,29 @@ +# YOLOv5Cls准备部署模型 + +- YOLOv5Cls v6.2部署模型实现来自[YOLOv5](https://github.com/ultralytics/yolov5/tree/v6.2),和[基于ImageNet的预训练模型](https://github.com/ultralytics/yolov5/releases/tag/v6.2) + - (1)[官方库](https://github.com/ultralytics/yolov5/releases/tag/v6.2)提供的*-cls.pt模型,使用[YOLOv5](https://github.com/ultralytics/yolov5)中的`export.py`导出ONNX文件后,可直接进行部署; + - (2)开发者基于自己数据训练的YOLOv5Cls v6.2模型,可使用[YOLOv5](https://github.com/ultralytics/yolov5)中的`export.py`导出ONNX文件后,完成部署。 + + +## 下载预训练ONNX模型 + +为了方便开发者的测试,下面提供了YOLOv5Cls导出的各系列模型,开发者可直接下载使用。(下表中模型的精度来源于源官方库) +| 模型 | 大小 | 精度(top1) | 精度(top5) | +|:---------------------------------------------------------------- |:----- |:----- |:----- | +| [YOLOv5n-cls](https://bj.bcebos.com/paddlehub/fastdeploy/yolov5n-cls.onnx) | 9.6MB | 64.6% | 85.4% | +| [YOLOv5s-cls](https://bj.bcebos.com/paddlehub/fastdeploy/yolov5s-cls.onnx) | 21MB | 71.5% | 90.2% | +| [YOLOv5m-cls](https://bj.bcebos.com/paddlehub/fastdeploy/yolov5m-cls.onnx) | 50MB | 75.9% | 92.9% | +| [YOLOv5l-cls](https://bj.bcebos.com/paddlehub/fastdeploy/yolov5l-cls.onnx) | 102MB | 78.0% | 94.0% | +| [YOLOv5x-cls](https://bj.bcebos.com/paddlehub/fastdeploy/yolov5x-cls.onnx) | 184MB | 79.0% | 94.4% | + + + + +## 详细部署文档 + +- [Python部署](python) +- [C++部署](cpp) + +## 版本说明 + +- 本版本文档和代码基于[YOLOv5 v6.2](https://github.com/ultralytics/yolov5/tree/v6.2) 编写 diff --git a/examples/vision/classification/yolov5cls/cpp/CMakeLists.txt b/examples/vision/classification/yolov5cls/cpp/CMakeLists.txt new file mode 100644 index 000000000..fea1a2888 --- /dev/null +++ b/examples/vision/classification/yolov5cls/cpp/CMakeLists.txt @@ -0,0 +1,14 @@ +PROJECT(infer_demo C CXX) +CMAKE_MINIMUM_REQUIRED (VERSION 3.12) + +# 指定下载解压后的fastdeploy库路径 +option(FASTDEPLOY_INSTALL_DIR "Path of downloaded fastdeploy sdk.") + +include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake) + +# 添加FastDeploy依赖头文件 +include_directories(${FASTDEPLOY_INCS}) + +add_executable(infer_demo ${PROJECT_SOURCE_DIR}/infer.cc) +# 添加FastDeploy库依赖 +target_link_libraries(infer_demo ${FASTDEPLOY_LIBS}) diff --git a/examples/vision/classification/yolov5cls/cpp/README.md b/examples/vision/classification/yolov5cls/cpp/README.md new file mode 100644 index 000000000..79f336845 --- /dev/null +++ b/examples/vision/classification/yolov5cls/cpp/README.md @@ -0,0 +1,89 @@ +# YOLOv5Cls C++部署示例 + +本目录下提供`infer.cc`快速完成YOLOv5Cls在CPU/GPU,以及GPU上通过TensorRT加速部署的示例。 + +在部署前,需确认以下两个步骤 + +- 1. 软硬件环境满足要求,参考[FastDeploy环境要求](../../../../../docs/environment.md) +- 2. 根据开发环境,下载预编译部署库和samples代码,参考[FastDeploy预编译库](../../../../../docs/quick_start) + +以Linux上CPU推理为例,在本目录执行如下命令即可完成编译测试 + +```bash +mkdir build +cd build +wget https://bj.bcebos.com/fastdeploy/release/cpp/fastdeploy-linux-x64-0.2.1.tgz +tar xvf fastdeploy-linux-x64-0.2.1.tgz +cmake .. -DFASTDEPLOY_INSTALL_DIR=${PWD}/fastdeploy-linux-x64-0.2.1 +make -j + +#下载官方转换好的yolov5模型文件和测试图片 +wget https://bj.bcebos.com/paddlehub/fastdeploy/yolov5n-cls.onnx +wget hhttps://gitee.com/paddlepaddle/PaddleClas/raw/release/2.4/deploy/images/ImageNet/ILSVRC2012_val_00000010.jpeg + + +# CPU推理 +./infer_demo yolov5n-cls.onnx 000000014439.jpg 0 +# GPU推理 +./infer_demo yolov5n-cls.onnx 000000014439.jpg 1 +# GPU上TensorRT推理 +./infer_demo yolov5n-cls.onnx 000000014439.jpg 2 +``` + +运行完成后返回结果如下所示 +```bash +ClassifyResult( +label_ids: 265, +scores: 0.196327, +) +``` + +以上命令只适用于Linux或MacOS, Windows下SDK的使用方式请参考: +- [如何在Windows中使用FastDeploy C++ SDK](../../../../../docs/compile/how_to_use_sdk_on_windows.md) + +## YOLOv5Cls C++接口 + +### YOLOv5Cls类 + +```c++ +fastdeploy::vision::classification::YOLOv5Cls( + const string& model_file, + const string& params_file = "", + const RuntimeOption& runtime_option = RuntimeOption(), + const ModelFormat& model_format = ModelFormat::ONNX) +``` + +YOLOv5Cls模型加载和初始化,其中model_file为导出的ONNX模型格式。 + +**参数** + +> * **model_file**(str): 模型文件路径 +> * **params_file**(str): 参数文件路径,当模型格式为ONNX时,此参数传入空字符串即可 +> * **runtime_option**(RuntimeOption): 后端推理配置,默认为None,即采用默认配置 +> * **model_format**(ModelFormat): 模型格式,默认为ONNX格式 + +#### Predict函数 + +> ```c++ +> YOLOv5Cls::Predict(cv::Mat* im, int topk = 1) +> ``` +> +> 模型预测接口,输入图像直接输出输出分类topk结果。 +> +> **参数** +> +> > * **input_image**(np.ndarray): 输入数据,注意需为HWC,BGR格式 +> > * **topk**(int):返回预测概率最高的topk个分类结果,默认为1 + + +> **返回** +> +> > 返回`fastdeploy.vision.ClassifyResult`结构体,结构体说明参考文档[视觉模型预测结果](../../../../../docs/api/vision_results/) + + +## 其它文档 + +- [YOLOv5Cls 模型介绍](..) +- [YOLOv5Cls Python部署](../python) +- [模型预测结果说明](../../../../../docs/api/vision_results/) +- [如何切换模型推理后端引擎](../../../../../docs/runtime/how_to_change_backend.md) diff --git a/examples/vision/classification/yolov5cls/cpp/infer.cc b/examples/vision/classification/yolov5cls/cpp/infer.cc new file mode 100644 index 000000000..2920c95b0 --- /dev/null +++ b/examples/vision/classification/yolov5cls/cpp/infer.cc @@ -0,0 +1,104 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fastdeploy/vision.h" +#ifdef WIN32 +const char sep = '\\'; +#else +const char sep = '/'; +#endif + +void CpuInfer(const std::string& model_file, const std::string& image_file) { + auto model = fastdeploy::vision::classification::YOLOv5Cls(model_file); + if (!model.Initialized()) { + std::cerr << "Failed to initialize." << std::endl; + return; + } + + auto im = cv::imread(image_file); + auto im_bak = im.clone(); + + fastdeploy::vision::ClassifyResult res; + if (!model.Predict(&im, &res)) { + std::cerr << "Failed to predict." << std::endl; + return; + } + // print res + std::cout << res.Str() << std::endl; +} + +void GpuInfer(const std::string& model_file, const std::string& image_file) { + auto option = fastdeploy::RuntimeOption(); + option.UseGpu(); + auto model = fastdeploy::vision::classification::YOLOv5Cls(model_file, "", option); + if (!model.Initialized()) { + std::cerr << "Failed to initialize." << std::endl; + return; + } + + auto im = cv::imread(image_file); + auto im_bak = im.clone(); + + fastdeploy::vision::ClassifyResult res; + if (!model.Predict(&im, &res)) { + std::cerr << "Failed to predict." << std::endl; + return; + } + // print res + std::cout << res.Str() << std::endl; +} + +void TrtInfer(const std::string& model_file, const std::string& image_file) { + auto option = fastdeploy::RuntimeOption(); + option.UseGpu(); + option.UseTrtBackend(); + option.SetTrtInputShape("images", {1, 3, 224, 224}); + auto model = fastdeploy::vision::classification::YOLOv5Cls(model_file, "", option); + if (!model.Initialized()) { + std::cerr << "Failed to initialize." << std::endl; + return; + } + + auto im = cv::imread(image_file); + auto im_bak = im.clone(); + + fastdeploy::vision::ClassifyResult res; + if (!model.Predict(&im, &res)) { + std::cerr << "Failed to predict." << std::endl; + return; + } + // print res + std::cout << res.Str() << std::endl; +} + +int main(int argc, char* argv[]) { + if (argc < 4) { + std::cout << "Usage: infer_demo path/to/model path/to/image run_option, " + "e.g ./infer_model ./yolov5n-cls.onnx ./test.jpeg 0" + << std::endl; + std::cout << "The data type of run_option is int, 0: run with cpu; 1: run " + "with gpu; 2: run with gpu and use tensorrt backend." + << std::endl; + return -1; + } + + if (std::atoi(argv[3]) == 0) { + CpuInfer(argv[1], argv[2]); + } else if (std::atoi(argv[3]) == 1) { + GpuInfer(argv[1], argv[2]); + } else if (std::atoi(argv[3]) == 2) { + TrtInfer(argv[1], argv[2]); + } + return 0; +} diff --git a/examples/vision/classification/yolov5cls/python/README.md b/examples/vision/classification/yolov5cls/python/README.md new file mode 100644 index 000000000..b65021ba3 --- /dev/null +++ b/examples/vision/classification/yolov5cls/python/README.md @@ -0,0 +1,73 @@ +# YOLOv5Cls Python部署示例 + +在部署前,需确认以下两个步骤 + +- 1. 软硬件环境满足要求,参考[FastDeploy环境要求](../../../../../docs/environment.md) +- 2. FastDeploy Python whl包安装,参考[FastDeploy Python安装](../../../../../docs/quick_start) + +本目录下提供`infer.py`快速完成YOLOv5Cls在CPU/GPU,以及GPU上通过TensorRT加速部署的示例。执行如下脚本即可完成 + +```bash +#下载部署示例代码 +git clone https://github.com/PaddlePaddle/FastDeploy.git +cd examples/vision/classification/yolov5cls/python/ + +#下载yolov5cls模型文件和测试图片 +wget https://bj.bcebos.com/paddlehub/fastdeploy/yolov5n-cls.onnx +wget https://gitee.com/paddlepaddle/PaddleClas/raw/release/2.4/deploy/images/ImageNet/ILSVRC2012_val_00000010.jpeg + +# CPU推理 +python infer.py --model yolov5n-cls.onnx --image ILSVRC2012_val_00000010.jpeg --device cpu --topk 1 +# GPU推理 +python infer.py --model yolov5n-cls.onnx --image ILSVRC2012_val_00000010.jpeg --device gpu --topk 1 +# GPU上使用TensorRT推理 +python infer.py --model yolov5n-cls.onnx --image ILSVRC2012_val_00000010.jpeg --device gpu --use_trt True +``` + +运行完成后返回结果如下所示 +```bash +ClassifyResult( +label_ids: 265, +scores: 0.196327, +) +``` + +## YOLOv5Cls Python接口 + +```python +fastdeploy.vision.classification.YOLOv5Cls(model_file, params_file=None, runtime_option=None, model_format=ModelFormat.ONNX) +``` + +YOLOv5Cls模型加载和初始化,其中model_file为导出的ONNX模型格式 + +**参数** + +> * **model_file**(str): 模型文件路径 +> * **params_file**(str): 参数文件路径,当模型格式为ONNX格式时,此参数无需设定 +> * **runtime_option**(RuntimeOption): 后端推理配置,默认为None,即采用默认配置 +> * **model_format**(ModelFormat): 模型格式,默认为ONNX + +### predict函数 + +> ```python +> YOLOv5Cls.predict(image_data, topk=1) +> ``` +> +> 模型预测结口,输入图像直接输出分类topk结果。 +> +> **参数** +> +> > * **input_image**(np.ndarray): 输入数据,注意需为HWC,BGR格式 +> > * **topk**(int):返回预测概率最高的topk个分类结果,默认为1 + +> **返回** +> +> > 返回`fastdeploy.vision.ClassifyResult`结构体,结构体说明参考文档[视觉模型预测结果](../../../../../docs/api/vision_results/) + + +## 其它文档 + +- [YOLOv5Cls 模型介绍](..) +- [YOLOv5Cls C++部署](../cpp) +- [模型预测结果说明](../../../../../docs/api/vision_results/) +- [如何切换模型推理后端引擎](../../../../../docs/runtime/how_to_change_backend.md) diff --git a/examples/vision/classification/yolov5cls/python/infer.py b/examples/vision/classification/yolov5cls/python/infer.py new file mode 100644 index 000000000..576db32f2 --- /dev/null +++ b/examples/vision/classification/yolov5cls/python/infer.py @@ -0,0 +1,51 @@ +import fastdeploy as fd +import cv2 +import os + + +def parse_arguments(): + import argparse + import ast + parser = argparse.ArgumentParser() + parser.add_argument( + "--model", required=True, help="Path of YOLOv5Cls model.") + parser.add_argument( + "--image", type=str, required=True, help="Path of test image file.") + parser.add_argument( + "--topk", type=int, default=1, help="Return topk results.") + parser.add_argument( + "--device", + type=str, + default='cpu', + help="Type of inference device, support 'cpu' or 'gpu'.") + parser.add_argument( + "--use_trt", + type=ast.literal_eval, + default=False, + help="Wether to use tensorrt.") + return parser.parse_args() + + +def build_option(args): + option = fd.RuntimeOption() + + if args.device.lower() == "gpu": + option.use_gpu() + + if args.use_trt: + option.use_trt_backend() + option.set_trt_input_shape("images", [1, 3, 224, 224]) + return option + + +args = parse_arguments() + +# 配置runtime,加载模型 +runtime_option = build_option(args) +model = fd.vision.classification.YOLOv5Cls( + args.model, runtime_option=runtime_option) + +# 预测图片分类结果 +im = cv2.imread(args.image) +result = model.predict(im.copy(), args.topk) +print(result) diff --git a/examples/vision/detection/yolov5/README.md b/examples/vision/detection/yolov5/README.md index 302e77a7b..222cb53f7 100644 --- a/examples/vision/detection/yolov5/README.md +++ b/examples/vision/detection/yolov5/README.md @@ -2,7 +2,7 @@ - YOLOv5 v6.0部署模型实现来自[YOLOv5](https://github.com/ultralytics/yolov5/tree/v6.0),和[基于COCO的预训练模型](https://github.com/ultralytics/yolov5/releases/tag/v6.0) - (1)[官方库](https://github.com/ultralytics/yolov5/releases/tag/v6.0)提供的*.onnx可直接进行部署; - - (2)开发者基于自己数据训练的YOLOv5 v6.0模型,可使用[YOLOv5](https://github.com/ultralytics/yolov5)中的`export.py`导出ONNX文件后后,完成部署。 + - (2)开发者基于自己数据训练的YOLOv5 v6.0模型,可使用[YOLOv5](https://github.com/ultralytics/yolov5)中的`export.py`导出ONNX文件后,完成部署。 ## 下载预训练ONNX模型 diff --git a/fastdeploy/vision.h b/fastdeploy/vision.h old mode 100644 new mode 100755 index e7590f828..8e1358d8a --- a/fastdeploy/vision.h +++ b/fastdeploy/vision.h @@ -15,6 +15,7 @@ #include "fastdeploy/core/config.h" #ifdef ENABLE_VISION +#include "fastdeploy/vision/classification/contrib/yolov5cls.h" #include "fastdeploy/vision/classification/ppcls/model.h" #include "fastdeploy/vision/detection/contrib/nanodet_plus.h" #include "fastdeploy/vision/detection/contrib/scaledyolov4.h" diff --git a/fastdeploy/vision/classification/classification_pybind.cc b/fastdeploy/vision/classification/classification_pybind.cc index fe64a1996..497d692c3 100644 --- a/fastdeploy/vision/classification/classification_pybind.cc +++ b/fastdeploy/vision/classification/classification_pybind.cc @@ -16,11 +16,13 @@ namespace fastdeploy { +void BindYOLOv5Cls(pybind11::module& m); void BindPaddleClas(pybind11::module& m); void BindClassification(pybind11::module& m) { auto classification_module = m.def_submodule("classification", "Image classification models."); + BindYOLOv5Cls(classification_module); BindPaddleClas(classification_module); } } // namespace fastdeploy diff --git a/fastdeploy/vision/classification/contrib/yolov5cls.cc b/fastdeploy/vision/classification/contrib/yolov5cls.cc new file mode 100755 index 000000000..59f6740be --- /dev/null +++ b/fastdeploy/vision/classification/contrib/yolov5cls.cc @@ -0,0 +1,116 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fastdeploy/vision/classification/contrib/yolov5cls.h" + +#include "fastdeploy/utils/perf.h" +#include "fastdeploy/vision/utils/utils.h" + +namespace fastdeploy { +namespace vision { +namespace classification { + +YOLOv5Cls::YOLOv5Cls(const std::string& model_file, + const std::string& params_file, + const RuntimeOption& custom_option, + const ModelFormat& model_format) { + if (model_format == ModelFormat::ONNX) { + valid_cpu_backends = {Backend::OPENVINO, Backend::ORT}; + valid_gpu_backends = {Backend::ORT, Backend::TRT}; + } else { + valid_cpu_backends = {Backend::PDINFER, Backend::ORT}; + valid_gpu_backends = {Backend::PDINFER, Backend::ORT, Backend::TRT}; + } + runtime_option = custom_option; + runtime_option.model_format = model_format; + runtime_option.model_file = model_file; + runtime_option.params_file = params_file; + initialized = Initialize(); +} + +bool YOLOv5Cls::Initialize() { + // preprocess parameters + size = {224, 224}; + if (!InitRuntime()) { + FDERROR << "Failed to initialize fastdeploy backend." << std::endl; + return false; + } + return true; +} + +bool YOLOv5Cls::Preprocess(Mat* mat, FDTensor* output, + const std::vector& size) { + // CenterCrop + int crop_size = std::min(mat->Height(), mat->Width()); + CenterCrop::Run(mat, crop_size, crop_size); + Resize::Run(mat, size[0], size[1], -1, -1, cv::INTER_LINEAR); + // Normalize + BGR2RGB::Run(mat); + std::vector alpha = {1.0f / 255.0f, 1.0f / 255.0f, 1.0f / 255.0f}; + std::vector beta = {0.0f, 0.0f, 0.0f}; + Convert::Run(mat, alpha, beta); + std::vector mean = {0.485f, 0.456f, 0.406f}; + std::vector std = {0.229f, 0.224f, 0.225f}; + Normalize::Run(mat, mean, std, false); + HWC2CHW::Run(mat); + Cast::Run(mat, "float"); + + mat->ShareWithTensor(output); + output->shape.insert(output->shape.begin(), 1); + return true; +} + +bool YOLOv5Cls::Postprocess(const FDTensor& infer_result, + ClassifyResult* result, int topk) { + // Softmax + FDTensor infer_result_softmax; + Softmax(infer_result, &infer_result_softmax, 1); + int num_classes = infer_result_softmax.shape[1]; + const float* infer_result_buffer = + reinterpret_cast(infer_result_softmax.Data()); + topk = std::min(num_classes, topk); + result->label_ids = + utils::TopKIndices(infer_result_buffer, num_classes, topk); + result->scores.resize(topk); + for (int i = 0; i < topk; ++i) { + result->scores[i] = *(infer_result_buffer + result->label_ids[i]); + } + return true; +} + +bool YOLOv5Cls::Predict(cv::Mat* im, ClassifyResult* result, int topk) { + Mat mat(*im); + std::vector input_tensors(1); + if (!Preprocess(&mat, &input_tensors[0], size)) { + FDERROR << "Failed to preprocess input image." << std::endl; + return false; + } + + input_tensors[0].name = InputInfoOfRuntime(0).name; + std::vector output_tensors(1); + if (!Infer(input_tensors, &output_tensors)) { + FDERROR << "Failed to inference." << std::endl; + return false; + } + + if (!Postprocess(output_tensors[0], result, topk)) { + FDERROR << "Failed to post process." << std::endl; + return false; + } + return true; +} + +} // namespace classification +} // namespace vision +} // namespace fastdeploy diff --git a/fastdeploy/vision/classification/contrib/yolov5cls.h b/fastdeploy/vision/classification/contrib/yolov5cls.h new file mode 100755 index 000000000..1e2ff3f99 --- /dev/null +++ b/fastdeploy/vision/classification/contrib/yolov5cls.h @@ -0,0 +1,70 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "fastdeploy/fastdeploy_model.h" +#include "fastdeploy/vision/common/processors/transform.h" +#include "fastdeploy/vision/common/result.h" + +namespace fastdeploy { +namespace vision { +/** \brief All image classification model APIs are defined inside this namespace + * + */ +namespace classification { + +/*! @brief YOLOv5Cls model object used when to load a YOLOv5Cls model exported by YOLOv5 + */ +class FASTDEPLOY_DECL YOLOv5Cls : public FastDeployModel { + public: + /** \brief Set path of model file and configuration file, and the configuration of runtime + * + * \param[in] model_file Path of model file, e.g yolov5cls/yolov5n-cls.onnx + * \param[in] params_file Path of parameter file, if the model format is ONNX, this parameter will be ignored + * \param[in] custom_option RuntimeOption for inference, the default will use cpu, and choose the backend defined in `valid_cpu_backends` + * \param[in] model_format Model format of the loaded model, default is ONNX format + */ + YOLOv5Cls(const std::string& model_file, const std::string& params_file = "", + const RuntimeOption& custom_option = RuntimeOption(), + const ModelFormat& model_format = ModelFormat::ONNX); + + /// Get model's name + virtual std::string ModelName() const { return "yolov5cls"; } + + /** \brief Predict the classification result for an input image + * + * \param[in] im The input image data, comes from cv::imread() + * \param[in] result The output classification result will be writen to this structure + * \param[in] topk Returns the topk classification result with the highest predicted probability, the default is 1 + * \return true if the prediction successed, otherwise false + */ + virtual bool Predict(cv::Mat* im, ClassifyResult* result, int topk = 1); + + /// Preprocess image size, the default is (224, 224) + std::vector size; + + private: + bool Initialize(); + /// Preprocess an input image, and set the preprocessed results to `outputs` + bool Preprocess(Mat* mat, FDTensor* output, + const std::vector& size = {224, 224}); + + /// Postprocess the inferenced results, and set the final result to `result` + bool Postprocess(const FDTensor& infer_result, ClassifyResult* result, + int topk = 1); +}; + +} // namespace classification +} // namespace vision +} // namespace fastdeploy diff --git a/fastdeploy/vision/classification/contrib/yolov5cls_pybind.cc b/fastdeploy/vision/classification/contrib/yolov5cls_pybind.cc new file mode 100755 index 000000000..5a42dec38 --- /dev/null +++ b/fastdeploy/vision/classification/contrib/yolov5cls_pybind.cc @@ -0,0 +1,32 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "fastdeploy/pybind/main.h" + +namespace fastdeploy { +void BindYOLOv5Cls(pybind11::module& m) { + pybind11::class_( + m, "YOLOv5Cls") + .def(pybind11::init()) + .def("predict", + [](vision::classification::YOLOv5Cls& self, pybind11::array& data, + int topk = 1) { + auto mat = PyArrayToCvMat(data); + vision::ClassifyResult res; + self.Predict(&mat, &res, topk); + return res; + }) + .def_readwrite("size", &vision::classification::YOLOv5Cls::size); +} +} // namespace fastdeploy diff --git a/python/fastdeploy/vision/classification/__init__.py b/python/fastdeploy/vision/classification/__init__.py index 3ebcf9c2f..ceeaa024a 100644 --- a/python/fastdeploy/vision/classification/__init__.py +++ b/python/fastdeploy/vision/classification/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. from __future__ import absolute_import +from .contrib.yolov5cls import YOLOv5Cls from .ppcls import PaddleClasModel PPLCNet = PaddleClasModel diff --git a/python/fastdeploy/vision/classification/contrib/__init__.py b/python/fastdeploy/vision/classification/contrib/__init__.py new file mode 100644 index 000000000..8034e10bf --- /dev/null +++ b/python/fastdeploy/vision/classification/contrib/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import diff --git a/python/fastdeploy/vision/classification/contrib/yolov5cls.py b/python/fastdeploy/vision/classification/contrib/yolov5cls.py new file mode 100644 index 000000000..8a4744e56 --- /dev/null +++ b/python/fastdeploy/vision/classification/contrib/yolov5cls.py @@ -0,0 +1,69 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +import logging +from .... import FastDeployModel, ModelFormat +from .... import c_lib_wrap as C + + +class YOLOv5Cls(FastDeployModel): + def __init__(self, + model_file, + params_file="", + runtime_option=None, + model_format=ModelFormat.ONNX): + """Load a image classification model exported by YOLOv5. + + :param model_file: (str)Path of model file, e.g yolov5cls/yolov5n-cls.onnx + :param params_file: (str)Path of parameters file, if the model_fomat is ModelFormat.ONNX, this param will be ignored, can be set as empty string + :param runtime_option: (fastdeploy.RuntimeOption)RuntimeOption for inference this model, if it's None, will use the default backend on CPU + :param model_format: (fastdeploy.ModelForamt)Model format of the loaded model, default is ONNX + """ + + super(YOLOv5Cls, self).__init__(runtime_option) + + assert model_format == ModelFormat.ONNX, "YOLOv5Cls only support model format of ModelFormat.ONNX now." + self._model = C.vision.classification.YOLOv5Cls( + model_file, params_file, self._runtime_option, model_format) + assert self.initialized, "YOLOv5Cls initialize failed." + + def predict(self, input_image, topk=1): + """Classify an input image + + :param im: (numpy.ndarray)The input image data, 3-D array with layout HWC, BGR format + :param topk: (int)The topk result by the classify confidence score, default 1 + :return: ClassifyResult + """ + + return self._model.predict(input_image, topk) + + @property + def size(self): + """ + Returns the preprocess image size + """ + return self._model.size + + @size.setter + def size(self, wh): + """ + Set the preprocess image size + """ + assert isinstance(wh, (list, tuple)),\ + "The value to set `size` must be type of tuple or list." + assert len(wh) == 2,\ + "The value to set `size` must contatins 2 elements means [width, height], but now it contains {} elements.".format( + len(wh)) + self._model.size = wh diff --git a/tests/eval_example/test_ppmatting.py b/tests/eval_example/test_ppmatting.py index 190f3017f..f26fd358f 100644 --- a/tests/eval_example/test_ppmatting.py +++ b/tests/eval_example/test_ppmatting.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/eval_example/test_quantize_diff.py b/tests/eval_example/test_quantize_diff.py index 2c9454dd3..8bc7b396a 100755 --- a/tests/eval_example/test_quantize_diff.py +++ b/tests/eval_example/test_quantize_diff.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/eval_example/test_yolov5cls.py b/tests/eval_example/test_yolov5cls.py new file mode 100755 index 000000000..50eefa36c --- /dev/null +++ b/tests/eval_example/test_yolov5cls.py @@ -0,0 +1,49 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import fastdeploy as fd +import cv2 +import os +import pickle +import numpy as np + + +def test_classification_yolov5cls(): + model_url = "https://bj.bcebos.com/paddlehub/fastdeploy/yolov5n-cls.tgz" + input_url = "https://gitee.com/paddlepaddle/PaddleClas/raw/release/2.4/deploy/images/ImageNet/ILSVRC2012_val_00000010.jpeg" + fd.download_and_decompress(model_url, ".") + fd.download(input_url, ".") + model_path = "yolov5n-cls/yolov5n-cls.onnx" + # use ORT + runtime_option = fd.RuntimeOption() + runtime_option.use_ort_backend() + model = fd.vision.classification.YOLOv5Cls( + model_path, runtime_option=runtime_option) + + # compare diff + im = cv2.imread("./ILSVRC2012_val_00000010.jpeg") + result = model.predict(im.copy(), topk=5) + with open("yolov5n-cls/result.pkl", "rb") as f: + expect = pickle.load(f) + + diff_label = np.fabs( + np.array(result.label_ids) - np.array(expect["labels"])) + diff_score = np.fabs(np.array(result.scores) - np.array(expect["scores"])) + thres = 1e-05 + assert diff_label.max( + ) < thres, "The label diff is %f, which is bigger than %f" % ( + diff_label.max(), thres) + assert diff_score.max( + ) < thres, "The score diff is %f, which is bigger than %f" % ( + diff_score.max(), thres)