Add YOLOv5-cls Model (#335)

* add yolov5cls

* fixed bugs

* fixed bugs

* fixed preprocess bug

* add yolov5cls readme

* deal with comments

* Add YOLOv5Cls Note

* add yolov5cls test

Co-authored-by: Jason <jiangjiajun@baidu.com>
This commit is contained in:
WJJ1995
2022-10-12 15:57:26 +08:00
committed by GitHub
parent 945e197bd1
commit b557dbc2d8
19 changed files with 719 additions and 4 deletions

View File

@@ -55,7 +55,7 @@ PaddleClas模型加载和初始化其中model_file, params_file为训练模
> PaddleClasModel.predict(input_image, topk=1)
> ```
>
> 模型预测结口,输入图像直接输出检测结果。
> 模型预测结口,输入图像直接输出分类topk结果。
>
> **参数**
>

View File

@@ -0,0 +1,29 @@
# YOLOv5Cls准备部署模型
- YOLOv5Cls v6.2部署模型实现来自[YOLOv5](https://github.com/ultralytics/yolov5/tree/v6.2),和[基于ImageNet的预训练模型](https://github.com/ultralytics/yolov5/releases/tag/v6.2)
- 1[官方库](https://github.com/ultralytics/yolov5/releases/tag/v6.2)提供的*-cls.pt模型使用[YOLOv5](https://github.com/ultralytics/yolov5)中的`export.py`导出ONNX文件后可直接进行部署
- 2开发者基于自己数据训练的YOLOv5Cls v6.2模型,可使用[YOLOv5](https://github.com/ultralytics/yolov5)中的`export.py`导出ONNX文件后完成部署。
## 下载预训练ONNX模型
为了方便开发者的测试下面提供了YOLOv5Cls导出的各系列模型开发者可直接下载使用。下表中模型的精度来源于源官方库
| 模型 | 大小 | 精度(top1) | 精度(top5) |
|:---------------------------------------------------------------- |:----- |:----- |:----- |
| [YOLOv5n-cls](https://bj.bcebos.com/paddlehub/fastdeploy/yolov5n-cls.onnx) | 9.6MB | 64.6% | 85.4% |
| [YOLOv5s-cls](https://bj.bcebos.com/paddlehub/fastdeploy/yolov5s-cls.onnx) | 21MB | 71.5% | 90.2% |
| [YOLOv5m-cls](https://bj.bcebos.com/paddlehub/fastdeploy/yolov5m-cls.onnx) | 50MB | 75.9% | 92.9% |
| [YOLOv5l-cls](https://bj.bcebos.com/paddlehub/fastdeploy/yolov5l-cls.onnx) | 102MB | 78.0% | 94.0% |
| [YOLOv5x-cls](https://bj.bcebos.com/paddlehub/fastdeploy/yolov5x-cls.onnx) | 184MB | 79.0% | 94.4% |
## 详细部署文档
- [Python部署](python)
- [C++部署](cpp)
## 版本说明
- 本版本文档和代码基于[YOLOv5 v6.2](https://github.com/ultralytics/yolov5/tree/v6.2) 编写

View File

@@ -0,0 +1,14 @@
PROJECT(infer_demo C CXX)
CMAKE_MINIMUM_REQUIRED (VERSION 3.12)
# 指定下载解压后的fastdeploy库路径
option(FASTDEPLOY_INSTALL_DIR "Path of downloaded fastdeploy sdk.")
include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake)
# 添加FastDeploy依赖头文件
include_directories(${FASTDEPLOY_INCS})
add_executable(infer_demo ${PROJECT_SOURCE_DIR}/infer.cc)
# 添加FastDeploy库依赖
target_link_libraries(infer_demo ${FASTDEPLOY_LIBS})

View File

@@ -0,0 +1,89 @@
# YOLOv5Cls C++部署示例
本目录下提供`infer.cc`快速完成YOLOv5Cls在CPU/GPU以及GPU上通过TensorRT加速部署的示例。
在部署前,需确认以下两个步骤
- 1. 软硬件环境满足要求,参考[FastDeploy环境要求](../../../../../docs/environment.md)
- 2. 根据开发环境下载预编译部署库和samples代码参考[FastDeploy预编译库](../../../../../docs/quick_start)
以Linux上CPU推理为例在本目录执行如下命令即可完成编译测试
```bash
mkdir build
cd build
wget https://bj.bcebos.com/fastdeploy/release/cpp/fastdeploy-linux-x64-0.2.1.tgz
tar xvf fastdeploy-linux-x64-0.2.1.tgz
cmake .. -DFASTDEPLOY_INSTALL_DIR=${PWD}/fastdeploy-linux-x64-0.2.1
make -j
#下载官方转换好的yolov5模型文件和测试图片
wget https://bj.bcebos.com/paddlehub/fastdeploy/yolov5n-cls.onnx
wget hhttps://gitee.com/paddlepaddle/PaddleClas/raw/release/2.4/deploy/images/ImageNet/ILSVRC2012_val_00000010.jpeg
# CPU推理
./infer_demo yolov5n-cls.onnx 000000014439.jpg 0
# GPU推理
./infer_demo yolov5n-cls.onnx 000000014439.jpg 1
# GPU上TensorRT推理
./infer_demo yolov5n-cls.onnx 000000014439.jpg 2
```
运行完成后返回结果如下所示
```bash
ClassifyResult(
label_ids: 265,
scores: 0.196327,
)
```
以上命令只适用于Linux或MacOS, Windows下SDK的使用方式请参考:
- [如何在Windows中使用FastDeploy C++ SDK](../../../../../docs/compile/how_to_use_sdk_on_windows.md)
## YOLOv5Cls C++接口
### YOLOv5Cls类
```c++
fastdeploy::vision::classification::YOLOv5Cls(
const string& model_file,
const string& params_file = "",
const RuntimeOption& runtime_option = RuntimeOption(),
const ModelFormat& model_format = ModelFormat::ONNX)
```
YOLOv5Cls模型加载和初始化其中model_file为导出的ONNX模型格式。
**参数**
> * **model_file**(str): 模型文件路径
> * **params_file**(str): 参数文件路径当模型格式为ONNX时此参数传入空字符串即可
> * **runtime_option**(RuntimeOption): 后端推理配置默认为None即采用默认配置
> * **model_format**(ModelFormat): 模型格式默认为ONNX格式
#### Predict函数
> ```c++
> YOLOv5Cls::Predict(cv::Mat* im, int topk = 1)
> ```
>
> 模型预测接口输入图像直接输出输出分类topk结果。
>
> **参数**
>
> > * **input_image**(np.ndarray): 输入数据注意需为HWCBGR格式
> > * **topk**(int):返回预测概率最高的topk个分类结果默认为1
> **返回**
>
> > 返回`fastdeploy.vision.ClassifyResult`结构体,结构体说明参考文档[视觉模型预测结果](../../../../../docs/api/vision_results/)
## 其它文档
- [YOLOv5Cls 模型介绍](..)
- [YOLOv5Cls Python部署](../python)
- [模型预测结果说明](../../../../../docs/api/vision_results/)
- [如何切换模型推理后端引擎](../../../../../docs/runtime/how_to_change_backend.md)

View File

@@ -0,0 +1,104 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision.h"
#ifdef WIN32
const char sep = '\\';
#else
const char sep = '/';
#endif
void CpuInfer(const std::string& model_file, const std::string& image_file) {
auto model = fastdeploy::vision::classification::YOLOv5Cls(model_file);
if (!model.Initialized()) {
std::cerr << "Failed to initialize." << std::endl;
return;
}
auto im = cv::imread(image_file);
auto im_bak = im.clone();
fastdeploy::vision::ClassifyResult res;
if (!model.Predict(&im, &res)) {
std::cerr << "Failed to predict." << std::endl;
return;
}
// print res
std::cout << res.Str() << std::endl;
}
void GpuInfer(const std::string& model_file, const std::string& image_file) {
auto option = fastdeploy::RuntimeOption();
option.UseGpu();
auto model = fastdeploy::vision::classification::YOLOv5Cls(model_file, "", option);
if (!model.Initialized()) {
std::cerr << "Failed to initialize." << std::endl;
return;
}
auto im = cv::imread(image_file);
auto im_bak = im.clone();
fastdeploy::vision::ClassifyResult res;
if (!model.Predict(&im, &res)) {
std::cerr << "Failed to predict." << std::endl;
return;
}
// print res
std::cout << res.Str() << std::endl;
}
void TrtInfer(const std::string& model_file, const std::string& image_file) {
auto option = fastdeploy::RuntimeOption();
option.UseGpu();
option.UseTrtBackend();
option.SetTrtInputShape("images", {1, 3, 224, 224});
auto model = fastdeploy::vision::classification::YOLOv5Cls(model_file, "", option);
if (!model.Initialized()) {
std::cerr << "Failed to initialize." << std::endl;
return;
}
auto im = cv::imread(image_file);
auto im_bak = im.clone();
fastdeploy::vision::ClassifyResult res;
if (!model.Predict(&im, &res)) {
std::cerr << "Failed to predict." << std::endl;
return;
}
// print res
std::cout << res.Str() << std::endl;
}
int main(int argc, char* argv[]) {
if (argc < 4) {
std::cout << "Usage: infer_demo path/to/model path/to/image run_option, "
"e.g ./infer_model ./yolov5n-cls.onnx ./test.jpeg 0"
<< std::endl;
std::cout << "The data type of run_option is int, 0: run with cpu; 1: run "
"with gpu; 2: run with gpu and use tensorrt backend."
<< std::endl;
return -1;
}
if (std::atoi(argv[3]) == 0) {
CpuInfer(argv[1], argv[2]);
} else if (std::atoi(argv[3]) == 1) {
GpuInfer(argv[1], argv[2]);
} else if (std::atoi(argv[3]) == 2) {
TrtInfer(argv[1], argv[2]);
}
return 0;
}

View File

@@ -0,0 +1,73 @@
# YOLOv5Cls Python部署示例
在部署前,需确认以下两个步骤
- 1. 软硬件环境满足要求,参考[FastDeploy环境要求](../../../../../docs/environment.md)
- 2. FastDeploy Python whl包安装参考[FastDeploy Python安装](../../../../../docs/quick_start)
本目录下提供`infer.py`快速完成YOLOv5Cls在CPU/GPU以及GPU上通过TensorRT加速部署的示例。执行如下脚本即可完成
```bash
#下载部署示例代码
git clone https://github.com/PaddlePaddle/FastDeploy.git
cd examples/vision/classification/yolov5cls/python/
#下载yolov5cls模型文件和测试图片
wget https://bj.bcebos.com/paddlehub/fastdeploy/yolov5n-cls.onnx
wget https://gitee.com/paddlepaddle/PaddleClas/raw/release/2.4/deploy/images/ImageNet/ILSVRC2012_val_00000010.jpeg
# CPU推理
python infer.py --model yolov5n-cls.onnx --image ILSVRC2012_val_00000010.jpeg --device cpu --topk 1
# GPU推理
python infer.py --model yolov5n-cls.onnx --image ILSVRC2012_val_00000010.jpeg --device gpu --topk 1
# GPU上使用TensorRT推理
python infer.py --model yolov5n-cls.onnx --image ILSVRC2012_val_00000010.jpeg --device gpu --use_trt True
```
运行完成后返回结果如下所示
```bash
ClassifyResult(
label_ids: 265,
scores: 0.196327,
)
```
## YOLOv5Cls Python接口
```python
fastdeploy.vision.classification.YOLOv5Cls(model_file, params_file=None, runtime_option=None, model_format=ModelFormat.ONNX)
```
YOLOv5Cls模型加载和初始化其中model_file为导出的ONNX模型格式
**参数**
> * **model_file**(str): 模型文件路径
> * **params_file**(str): 参数文件路径当模型格式为ONNX格式时此参数无需设定
> * **runtime_option**(RuntimeOption): 后端推理配置默认为None即采用默认配置
> * **model_format**(ModelFormat): 模型格式默认为ONNX
### predict函数
> ```python
> YOLOv5Cls.predict(image_data, topk=1)
> ```
>
> 模型预测结口输入图像直接输出分类topk结果。
>
> **参数**
>
> > * **input_image**(np.ndarray): 输入数据注意需为HWCBGR格式
> > * **topk**(int):返回预测概率最高的topk个分类结果默认为1
> **返回**
>
> > 返回`fastdeploy.vision.ClassifyResult`结构体,结构体说明参考文档[视觉模型预测结果](../../../../../docs/api/vision_results/)
## 其它文档
- [YOLOv5Cls 模型介绍](..)
- [YOLOv5Cls C++部署](../cpp)
- [模型预测结果说明](../../../../../docs/api/vision_results/)
- [如何切换模型推理后端引擎](../../../../../docs/runtime/how_to_change_backend.md)

View File

@@ -0,0 +1,51 @@
import fastdeploy as fd
import cv2
import os
def parse_arguments():
import argparse
import ast
parser = argparse.ArgumentParser()
parser.add_argument(
"--model", required=True, help="Path of YOLOv5Cls model.")
parser.add_argument(
"--image", type=str, required=True, help="Path of test image file.")
parser.add_argument(
"--topk", type=int, default=1, help="Return topk results.")
parser.add_argument(
"--device",
type=str,
default='cpu',
help="Type of inference device, support 'cpu' or 'gpu'.")
parser.add_argument(
"--use_trt",
type=ast.literal_eval,
default=False,
help="Wether to use tensorrt.")
return parser.parse_args()
def build_option(args):
option = fd.RuntimeOption()
if args.device.lower() == "gpu":
option.use_gpu()
if args.use_trt:
option.use_trt_backend()
option.set_trt_input_shape("images", [1, 3, 224, 224])
return option
args = parse_arguments()
# 配置runtime加载模型
runtime_option = build_option(args)
model = fd.vision.classification.YOLOv5Cls(
args.model, runtime_option=runtime_option)
# 预测图片分类结果
im = cv2.imread(args.image)
result = model.predict(im.copy(), args.topk)
print(result)

View File

@@ -2,7 +2,7 @@
- YOLOv5 v6.0部署模型实现来自[YOLOv5](https://github.com/ultralytics/yolov5/tree/v6.0),和[基于COCO的预训练模型](https://github.com/ultralytics/yolov5/releases/tag/v6.0)
- 1[官方库](https://github.com/ultralytics/yolov5/releases/tag/v6.0)提供的*.onnx可直接进行部署
- 2开发者基于自己数据训练的YOLOv5 v6.0模型,可使用[YOLOv5](https://github.com/ultralytics/yolov5)中的`export.py`导出ONNX文件后,完成部署。
- 2开发者基于自己数据训练的YOLOv5 v6.0模型,可使用[YOLOv5](https://github.com/ultralytics/yolov5)中的`export.py`导出ONNX文件后完成部署。
## 下载预训练ONNX模型

1
fastdeploy/vision.h Normal file → Executable file
View File

@@ -15,6 +15,7 @@
#include "fastdeploy/core/config.h"
#ifdef ENABLE_VISION
#include "fastdeploy/vision/classification/contrib/yolov5cls.h"
#include "fastdeploy/vision/classification/ppcls/model.h"
#include "fastdeploy/vision/detection/contrib/nanodet_plus.h"
#include "fastdeploy/vision/detection/contrib/scaledyolov4.h"

View File

@@ -16,11 +16,13 @@
namespace fastdeploy {
void BindYOLOv5Cls(pybind11::module& m);
void BindPaddleClas(pybind11::module& m);
void BindClassification(pybind11::module& m) {
auto classification_module =
m.def_submodule("classification", "Image classification models.");
BindYOLOv5Cls(classification_module);
BindPaddleClas(classification_module);
}
} // namespace fastdeploy

View File

@@ -0,0 +1,116 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/classification/contrib/yolov5cls.h"
#include "fastdeploy/utils/perf.h"
#include "fastdeploy/vision/utils/utils.h"
namespace fastdeploy {
namespace vision {
namespace classification {
YOLOv5Cls::YOLOv5Cls(const std::string& model_file,
const std::string& params_file,
const RuntimeOption& custom_option,
const ModelFormat& model_format) {
if (model_format == ModelFormat::ONNX) {
valid_cpu_backends = {Backend::OPENVINO, Backend::ORT};
valid_gpu_backends = {Backend::ORT, Backend::TRT};
} else {
valid_cpu_backends = {Backend::PDINFER, Backend::ORT};
valid_gpu_backends = {Backend::PDINFER, Backend::ORT, Backend::TRT};
}
runtime_option = custom_option;
runtime_option.model_format = model_format;
runtime_option.model_file = model_file;
runtime_option.params_file = params_file;
initialized = Initialize();
}
bool YOLOv5Cls::Initialize() {
// preprocess parameters
size = {224, 224};
if (!InitRuntime()) {
FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
return false;
}
return true;
}
bool YOLOv5Cls::Preprocess(Mat* mat, FDTensor* output,
const std::vector<int>& size) {
// CenterCrop
int crop_size = std::min(mat->Height(), mat->Width());
CenterCrop::Run(mat, crop_size, crop_size);
Resize::Run(mat, size[0], size[1], -1, -1, cv::INTER_LINEAR);
// Normalize
BGR2RGB::Run(mat);
std::vector<float> alpha = {1.0f / 255.0f, 1.0f / 255.0f, 1.0f / 255.0f};
std::vector<float> beta = {0.0f, 0.0f, 0.0f};
Convert::Run(mat, alpha, beta);
std::vector<float> mean = {0.485f, 0.456f, 0.406f};
std::vector<float> std = {0.229f, 0.224f, 0.225f};
Normalize::Run(mat, mean, std, false);
HWC2CHW::Run(mat);
Cast::Run(mat, "float");
mat->ShareWithTensor(output);
output->shape.insert(output->shape.begin(), 1);
return true;
}
bool YOLOv5Cls::Postprocess(const FDTensor& infer_result,
ClassifyResult* result, int topk) {
// Softmax
FDTensor infer_result_softmax;
Softmax(infer_result, &infer_result_softmax, 1);
int num_classes = infer_result_softmax.shape[1];
const float* infer_result_buffer =
reinterpret_cast<const float*>(infer_result_softmax.Data());
topk = std::min(num_classes, topk);
result->label_ids =
utils::TopKIndices(infer_result_buffer, num_classes, topk);
result->scores.resize(topk);
for (int i = 0; i < topk; ++i) {
result->scores[i] = *(infer_result_buffer + result->label_ids[i]);
}
return true;
}
bool YOLOv5Cls::Predict(cv::Mat* im, ClassifyResult* result, int topk) {
Mat mat(*im);
std::vector<FDTensor> input_tensors(1);
if (!Preprocess(&mat, &input_tensors[0], size)) {
FDERROR << "Failed to preprocess input image." << std::endl;
return false;
}
input_tensors[0].name = InputInfoOfRuntime(0).name;
std::vector<FDTensor> output_tensors(1);
if (!Infer(input_tensors, &output_tensors)) {
FDERROR << "Failed to inference." << std::endl;
return false;
}
if (!Postprocess(output_tensors[0], result, topk)) {
FDERROR << "Failed to post process." << std::endl;
return false;
}
return true;
}
} // namespace classification
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,70 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/fastdeploy_model.h"
#include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/result.h"
namespace fastdeploy {
namespace vision {
/** \brief All image classification model APIs are defined inside this namespace
*
*/
namespace classification {
/*! @brief YOLOv5Cls model object used when to load a YOLOv5Cls model exported by YOLOv5
*/
class FASTDEPLOY_DECL YOLOv5Cls : public FastDeployModel {
public:
/** \brief Set path of model file and configuration file, and the configuration of runtime
*
* \param[in] model_file Path of model file, e.g yolov5cls/yolov5n-cls.onnx
* \param[in] params_file Path of parameter file, if the model format is ONNX, this parameter will be ignored
* \param[in] custom_option RuntimeOption for inference, the default will use cpu, and choose the backend defined in `valid_cpu_backends`
* \param[in] model_format Model format of the loaded model, default is ONNX format
*/
YOLOv5Cls(const std::string& model_file, const std::string& params_file = "",
const RuntimeOption& custom_option = RuntimeOption(),
const ModelFormat& model_format = ModelFormat::ONNX);
/// Get model's name
virtual std::string ModelName() const { return "yolov5cls"; }
/** \brief Predict the classification result for an input image
*
* \param[in] im The input image data, comes from cv::imread()
* \param[in] result The output classification result will be writen to this structure
* \param[in] topk Returns the topk classification result with the highest predicted probability, the default is 1
* \return true if the prediction successed, otherwise false
*/
virtual bool Predict(cv::Mat* im, ClassifyResult* result, int topk = 1);
/// Preprocess image size, the default is (224, 224)
std::vector<int> size;
private:
bool Initialize();
/// Preprocess an input image, and set the preprocessed results to `outputs`
bool Preprocess(Mat* mat, FDTensor* output,
const std::vector<int>& size = {224, 224});
/// Postprocess the inferenced results, and set the final result to `result`
bool Postprocess(const FDTensor& infer_result, ClassifyResult* result,
int topk = 1);
};
} // namespace classification
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,32 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/pybind/main.h"
namespace fastdeploy {
void BindYOLOv5Cls(pybind11::module& m) {
pybind11::class_<vision::classification::YOLOv5Cls, FastDeployModel>(
m, "YOLOv5Cls")
.def(pybind11::init<std::string, std::string, RuntimeOption,
ModelFormat>())
.def("predict",
[](vision::classification::YOLOv5Cls& self, pybind11::array& data,
int topk = 1) {
auto mat = PyArrayToCvMat(data);
vision::ClassifyResult res;
self.Predict(&mat, &res, topk);
return res;
})
.def_readwrite("size", &vision::classification::YOLOv5Cls::size);
}
} // namespace fastdeploy

View File

@@ -13,6 +13,7 @@
# limitations under the License.
from __future__ import absolute_import
from .contrib.yolov5cls import YOLOv5Cls
from .ppcls import PaddleClasModel
PPLCNet = PaddleClasModel

View File

@@ -0,0 +1,15 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import

View File

@@ -0,0 +1,69 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import logging
from .... import FastDeployModel, ModelFormat
from .... import c_lib_wrap as C
class YOLOv5Cls(FastDeployModel):
def __init__(self,
model_file,
params_file="",
runtime_option=None,
model_format=ModelFormat.ONNX):
"""Load a image classification model exported by YOLOv5.
:param model_file: (str)Path of model file, e.g yolov5cls/yolov5n-cls.onnx
:param params_file: (str)Path of parameters file, if the model_fomat is ModelFormat.ONNX, this param will be ignored, can be set as empty string
:param runtime_option: (fastdeploy.RuntimeOption)RuntimeOption for inference this model, if it's None, will use the default backend on CPU
:param model_format: (fastdeploy.ModelForamt)Model format of the loaded model, default is ONNX
"""
super(YOLOv5Cls, self).__init__(runtime_option)
assert model_format == ModelFormat.ONNX, "YOLOv5Cls only support model format of ModelFormat.ONNX now."
self._model = C.vision.classification.YOLOv5Cls(
model_file, params_file, self._runtime_option, model_format)
assert self.initialized, "YOLOv5Cls initialize failed."
def predict(self, input_image, topk=1):
"""Classify an input image
:param im: (numpy.ndarray)The input image data, 3-D array with layout HWC, BGR format
:param topk: (int)The topk result by the classify confidence score, default 1
:return: ClassifyResult
"""
return self._model.predict(input_image, topk)
@property
def size(self):
"""
Returns the preprocess image size
"""
return self._model.size
@size.setter
def size(self, wh):
"""
Set the preprocess image size
"""
assert isinstance(wh, (list, tuple)),\
"The value to set `size` must be type of tuple or list."
assert len(wh) == 2,\
"The value to set `size` must contatins 2 elements means [width, height], but now it contains {} elements.".format(
len(wh))
self._model.size = wh

View File

@@ -1,4 +1,4 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@@ -1,4 +1,4 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@@ -0,0 +1,49 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import fastdeploy as fd
import cv2
import os
import pickle
import numpy as np
def test_classification_yolov5cls():
model_url = "https://bj.bcebos.com/paddlehub/fastdeploy/yolov5n-cls.tgz"
input_url = "https://gitee.com/paddlepaddle/PaddleClas/raw/release/2.4/deploy/images/ImageNet/ILSVRC2012_val_00000010.jpeg"
fd.download_and_decompress(model_url, ".")
fd.download(input_url, ".")
model_path = "yolov5n-cls/yolov5n-cls.onnx"
# use ORT
runtime_option = fd.RuntimeOption()
runtime_option.use_ort_backend()
model = fd.vision.classification.YOLOv5Cls(
model_path, runtime_option=runtime_option)
# compare diff
im = cv2.imread("./ILSVRC2012_val_00000010.jpeg")
result = model.predict(im.copy(), topk=5)
with open("yolov5n-cls/result.pkl", "rb") as f:
expect = pickle.load(f)
diff_label = np.fabs(
np.array(result.label_ids) - np.array(expect["labels"]))
diff_score = np.fabs(np.array(result.scores) - np.array(expect["scores"]))
thres = 1e-05
assert diff_label.max(
) < thres, "The label diff is %f, which is bigger than %f" % (
diff_label.max(), thres)
assert diff_score.max(
) < thres, "The score diff is %f, which is bigger than %f" % (
diff_score.max(), thres)