mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
Update evaluation function to support calculate average inference time (#106)
* Update README.md * Update README.md * Update README.md * Create README.md * Update README.md * Update README.md * Update README.md * Update README.md * Add evaluation calculate time and fix some bugs * Update classification __init__ * Move to ppseg Co-authored-by: Jason <jiangjiajun@baidu.com>
This commit is contained in:
@@ -58,8 +58,6 @@ void OrtBackend::BuildOption(const OrtBackendOption& option) {
|
|||||||
<< std::endl;
|
<< std::endl;
|
||||||
option_.use_gpu = false;
|
option_.use_gpu = false;
|
||||||
} else {
|
} else {
|
||||||
FDASSERT(option.gpu_id == 0, "Requires gpu_id == 0, but now gpu_id = " +
|
|
||||||
std::to_string(option.gpu_id) + ".");
|
|
||||||
OrtCUDAProviderOptions cuda_options;
|
OrtCUDAProviderOptions cuda_options;
|
||||||
cuda_options.device_id = option.gpu_id;
|
cuda_options.device_id = option.gpu_id;
|
||||||
session_options_.AppendExecutionProvider_CUDA(cuda_options);
|
session_options_.AppendExecutionProvider_CUDA(cuda_options);
|
||||||
|
@@ -20,12 +20,14 @@ namespace fastdeploy {
|
|||||||
namespace vision {
|
namespace vision {
|
||||||
namespace classification {
|
namespace classification {
|
||||||
|
|
||||||
PaddleClasModel::PaddleClasModel(const std::string& model_file, const std::string& params_file,
|
PaddleClasModel::PaddleClasModel(const std::string& model_file,
|
||||||
const std::string& config_file, const RuntimeOption& custom_option,
|
const std::string& params_file,
|
||||||
const Frontend& model_format) {
|
const std::string& config_file,
|
||||||
|
const RuntimeOption& custom_option,
|
||||||
|
const Frontend& model_format) {
|
||||||
config_file_ = config_file;
|
config_file_ = config_file;
|
||||||
valid_cpu_backends = {Backend::ORT, Backend::PDINFER};
|
valid_cpu_backends = {Backend::ORT, Backend::PDINFER};
|
||||||
valid_gpu_backends = {Backend::ORT, Backend::PDINFER};
|
valid_gpu_backends = {Backend::ORT, Backend::PDINFER, Backend::TRT};
|
||||||
runtime_option = custom_option;
|
runtime_option = custom_option;
|
||||||
runtime_option.model_format = model_format;
|
runtime_option.model_format = model_format;
|
||||||
runtime_option.model_file = model_file;
|
runtime_option.model_file = model_file;
|
||||||
@@ -109,8 +111,8 @@ bool PaddleClasModel::Preprocess(Mat* mat, FDTensor* output) {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool PaddleClasModel::Postprocess(const FDTensor& infer_result, ClassifyResult* result,
|
bool PaddleClasModel::Postprocess(const FDTensor& infer_result,
|
||||||
int topk) {
|
ClassifyResult* result, int topk) {
|
||||||
int num_classes = infer_result.shape[1];
|
int num_classes = infer_result.shape[1];
|
||||||
const float* infer_result_buffer =
|
const float* infer_result_buffer =
|
||||||
reinterpret_cast<const float*>(infer_result.data.data());
|
reinterpret_cast<const float*>(infer_result.data.data());
|
||||||
@@ -148,6 +150,6 @@ bool PaddleClasModel::Predict(cv::Mat* im, ClassifyResult* result, int topk) {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace classification
|
} // namespace classification
|
||||||
} // namespace vision
|
} // namespace vision
|
||||||
} // namespace fastdeploy
|
} // namespace fastdeploy
|
||||||
|
@@ -12,11 +12,11 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
#include "fastdeploy/vision/utils/utils.h"
|
#include "fastdeploy/vision/segmentation/ppseg/model.h"
|
||||||
|
|
||||||
namespace fastdeploy {
|
namespace fastdeploy {
|
||||||
namespace vision {
|
namespace vision {
|
||||||
namespace utils {
|
namespace segmentation {
|
||||||
|
|
||||||
void FDTensor2FP32CVMat(cv::Mat& mat, FDTensor& infer_result,
|
void FDTensor2FP32CVMat(cv::Mat& mat, FDTensor& infer_result,
|
||||||
bool contain_score_map) {
|
bool contain_score_map) {
|
||||||
@@ -54,6 +54,6 @@ void FDTensor2FP32CVMat(cv::Mat& mat, FDTensor& infer_result,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace utils
|
} // namespace segmentation
|
||||||
} // namespace vision
|
} // namespace vision
|
||||||
} // namespace fastdeploy
|
} // namespace fastdeploy
|
@@ -143,8 +143,7 @@ bool PaddleSegModel::Postprocess(
|
|||||||
Mat* mat = nullptr;
|
Mat* mat = nullptr;
|
||||||
if (is_resized) {
|
if (is_resized) {
|
||||||
cv::Mat temp_mat;
|
cv::Mat temp_mat;
|
||||||
utils::FDTensor2FP32CVMat(temp_mat, infer_result,
|
FDTensor2FP32CVMat(temp_mat, infer_result, result->contain_score_map);
|
||||||
result->contain_score_map);
|
|
||||||
|
|
||||||
// original image shape
|
// original image shape
|
||||||
auto iter_ipt = (*im_info).find("input_shape");
|
auto iter_ipt = (*im_info).find("input_shape");
|
||||||
|
@@ -38,6 +38,9 @@ class FASTDEPLOY_DECL PaddleSegModel : public FastDeployModel {
|
|||||||
std::vector<std::shared_ptr<Processor>> processors_;
|
std::vector<std::shared_ptr<Processor>> processors_;
|
||||||
std::string config_file_;
|
std::string config_file_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
void FDTensor2FP32CVMat(cv::Mat& mat, FDTensor& infer_result,
|
||||||
|
bool contain_score_map);
|
||||||
} // namespace segmentation
|
} // namespace segmentation
|
||||||
} // namespace vision
|
} // namespace vision
|
||||||
} // namespace fastdeploy
|
} // namespace fastdeploy
|
||||||
|
@@ -115,9 +115,6 @@ void NCHW2NHWC(FDTensor& infer_result) {
|
|||||||
infer_result.shape = {num, height, width, channel};
|
infer_result.shape = {num, height, width, channel};
|
||||||
}
|
}
|
||||||
|
|
||||||
void FDTensor2FP32CVMat(cv::Mat& mat, FDTensor& infer_result,
|
|
||||||
bool contain_score_map);
|
|
||||||
|
|
||||||
void NMS(DetectionResult* output, float iou_threshold = 0.5);
|
void NMS(DetectionResult* output, float iou_threshold = 0.5);
|
||||||
|
|
||||||
void NMS(FaceDetectionResult* result, float iou_threshold = 0.5);
|
void NMS(FaceDetectionResult* result, float iou_threshold = 0.5);
|
||||||
|
@@ -57,7 +57,6 @@ void BindVision(pybind11::module& m) {
|
|||||||
.def_readwrite("label_map", &vision::SegmentationResult::label_map)
|
.def_readwrite("label_map", &vision::SegmentationResult::label_map)
|
||||||
.def_readwrite("score_map", &vision::SegmentationResult::score_map)
|
.def_readwrite("score_map", &vision::SegmentationResult::score_map)
|
||||||
.def_readwrite("shape", &vision::SegmentationResult::shape)
|
.def_readwrite("shape", &vision::SegmentationResult::shape)
|
||||||
.def_readwrite("shape", &vision::SegmentationResult::shape)
|
|
||||||
.def("__repr__", &vision::SegmentationResult::Str)
|
.def("__repr__", &vision::SegmentationResult::Str)
|
||||||
.def("__str__", &vision::SegmentationResult::Str);
|
.def("__str__", &vision::SegmentationResult::Str);
|
||||||
|
|
||||||
|
@@ -2,25 +2,53 @@
|
|||||||
|
|
||||||
## 模型版本说明
|
## 模型版本说明
|
||||||
|
|
||||||
- [PaddleClas Release/2.4](https://github.com/PaddlePaddle/PaddleClas)
|
- [PaddleClas Release/2.4](https://github.com/PaddlePaddle/PaddleClas/tree/release/2.4)
|
||||||
|
|
||||||
|
目前FastDeploy支持如下模型的部署
|
||||||
|
|
||||||
|
- [PP-LCNet系列模型](https://github.com/PaddlePaddle/PaddleClas/blob/release/2.4/docs/zh_CN/models/PP-LCNet.md)
|
||||||
|
- [PP-LCNetV2系列模型](https://github.com/PaddlePaddle/PaddleClas/blob/release/2.4/docs/zh_CN/models/PP-LCNetV2.md)
|
||||||
|
- [EfficientNet系列模型](https://github.com/PaddlePaddle/PaddleClas/blob/release/2.4/docs/zh_CN/models/EfficientNet_and_ResNeXt101_wsl.md)
|
||||||
|
- [GhostNet系列模型](https://github.com/PaddlePaddle/PaddleClas/blob/release/2.4/docs/zh_CN/models/Mobile.md)
|
||||||
|
- [MobileNet系列模型(包含v1,v2,v3)](https://github.com/PaddlePaddle/PaddleClas/blob/release/2.4/docs/zh_CN/models/Mobile.md)
|
||||||
|
- [ShuffleNet系列模型](https://github.com/PaddlePaddle/PaddleClas/blob/release/2.4/docs/zh_CN/models/Mobile.md)
|
||||||
|
- [SqueezeNet系列模型](https://github.com/PaddlePaddle/PaddleClas/blob/release/2.4/docs/zh_CN/models/Others.md)
|
||||||
|
- [Inception系列模型](https://github.com/PaddlePaddle/PaddleClas/blob/release/2.4/docs/zh_CN/models/Inception.md)
|
||||||
|
- [PP-HGNet系列模型](https://github.com/PaddlePaddle/PaddleClas/blob/release/2.4/docs/zh_CN/models/PP-HGNet.md)
|
||||||
|
- [ResNet系列模型(包含vd系列)](https://github.com/PaddlePaddle/PaddleClas/blob/develop/docs/zh_CN/models/ResNet_and_vd.md)
|
||||||
|
|
||||||
## 准备PaddleClas部署模型
|
## 准备PaddleClas部署模型
|
||||||
|
|
||||||
PaddleClas模型导出,请参考其文档说明[模型导出](https://github.com/PaddlePaddle/PaddleClas/blob/release/2.4/docs/zh_CN/inference_deployment/export_model.md#2-%E5%88%86%E7%B1%BB%E6%A8%A1%E5%9E%8B%E5%AF%BC%E5%87%BA)
|
PaddleClas模型导出,请参考其文档说明[模型导出](https://github.com/PaddlePaddle/PaddleClas/blob/release/2.4/docs/zh_CN/inference_deployment/export_model.md#2-%E5%88%86%E7%B1%BB%E6%A8%A1%E5%9E%8B%E5%AF%BC%E5%87%BA)
|
||||||
|
|
||||||
注意:PaddleClas导出的模型仅包含`inference.pdmodel`和`inference.pdiparams`两个文档,但为了满足部署的需求,同时也需准备其提供的[inference_cls.yaml](https://github.com/PaddlePaddle/PaddleClas/blob/release/2.4/deploy/configs/inference_cls.yaml)文件,FastDeploy会从yaml文件中获取模型在推理时需要的预处理信息,开发者可直接下载此文件使用。但需根据自己的需求修改yaml文件中的配置参数。
|
注意:PaddleClas导出的模型仅包含`inference.pdmodel`和`inference.pdiparams`两个文档,但为了满足部署的需求,同时也需准备其提供的通用[inference_cls.yaml](https://github.com/PaddlePaddle/PaddleClas/blob/release/2.4/deploy/configs/inference_cls.yaml)文件,FastDeploy会从yaml文件中获取模型在推理时需要的预处理信息,开发者可直接下载此文件使用。但需根据自己的需求修改yaml文件中的配置参数,具体可比照PaddleClas模型训练[config](https://github.com/PaddlePaddle/PaddleClas/tree/release/2.4/ppcls/configs/ImageNet)中的infer部分的配置信息进行修改。
|
||||||
|
|
||||||
|
|
||||||
## 下载预训练模型
|
## 下载预训练模型
|
||||||
|
|
||||||
为了方便开发者的测试,下面提供了PaddleClas导出的部分模型(含inference_cls.yaml文件),开发者可直接下载使用。
|
为了方便开发者的测试,下面提供了PaddleClas导出的部分模型(含inference_cls.yaml文件),开发者可直接下载使用。
|
||||||
|
|
||||||
| 模型 | 大小 |输入Shape | 精度 |
|
| 模型 | 参数文件大小 |输入Shape | Top1 | Top5 |
|
||||||
|:---------------------------------------------------------------- |:----- |:----- | :----- |
|
|:---------------------------------------------------------------- |:----- |:----- | :----- | :----- |
|
||||||
| [PPLCNet]() | 141MB | 224x224 |51.4% |
|
| [PPLCNet_x1_0](https://bj.bcebos.com/paddlehub/fastdeploy/PPLCNet_x1_0_infer.tgz) | 12MB | 224x224 |71.32% | 90.03% |
|
||||||
| [PPLCNetv2]() | 10MB | 224x224 |51.4% |
|
| [PPLCNetV2_base](https://bj.bcebos.com/paddlehub/fastdeploy/PPLCNetV2_base_infer.tgz) | 26MB | 224x224 |77.04% | 93.27% |
|
||||||
| [EfficientNet]() | | 224x224 | |
|
| [EfficientNetB7](https://bj.bcebos.com/paddlehub/fastdeploy/EfficientNetB7_infer.tgz) | 255MB | 600x600 | 84.3% | 96.9% |
|
||||||
|
| [EfficientNetB0_small](https://bj.bcebos.com/paddlehub/fastdeploy/EfficientNetB0_small_infer.tgz)| 18MB | 224x224 | 75.8% | 75.8% |
|
||||||
|
| [GhostNet_x1_3_ssld](https://bj.bcebos.com/paddlehub/fastdeploy/GhostNet_x1_3_ssld_infer.tgz) | 29MB | 224x224 | 75.7% | 92.5% |
|
||||||
|
| [GhostNet_x0_5_ssld](https://bj.bcebos.com/paddlehub/fastdeploy/GhostNet_x0_5_infer.tgz) | 10MB | 224x224 | 66.8% | 86.9% |
|
||||||
|
| [MobileNetV1_x0_25](https://bj.bcebos.com/paddlehub/fastdeploy/MobileNetV1_x0_25_infer.tgz) | 1.9MB | 224x224 | 51.4% | 75.5% |
|
||||||
|
| [MobileNetV1_ssld](https://bj.bcebos.com/paddlehub/fastdeploy/MobileNetV1_ssld_infer.tgz) | 17MB | 224x224 | 77.9% | 93.9% |
|
||||||
|
| [MobileNetV2_x0_25](https://bj.bcebos.com/paddlehub/fastdeploy/MobileNetV2_x0_25_infer.tgz) | 5.9MB | 224x224 | 53.2% | 76.5% |
|
||||||
|
| [MobileNetV2_ssld](https://bj.bcebos.com/paddlehub/fastdeploy/MobileNetV2_ssld_infer.tgz) | 14MB | 224x224 | 76.74% | 93.39% |
|
||||||
|
| [MobileNetV3_small_x0_35_ssld](https://bj.bcebos.com/paddlehub/fastdeploy/MobileNetV3_small_x0_35_ssld_infer.tgz) | 6.4MB | 224x224 | 55.55% | 77.71% |
|
||||||
|
| [MobileNetV3_large_x1_0_ssld](https://bj.bcebos.com/paddlehub/fastdeploy/MobileNetV3_large_x1_0_ssld_infer.tgz) | 22MB | 224x224 | 78.96% | 94.48% |
|
||||||
|
| [ShuffleNetV2_x0_25](https://bj.bcebos.com/paddlehub/fastdeploy/ShuffleNetV2_x0_25_infer.tgz) | 2.4MB | 224x224 | 49.9% | 73.79% |
|
||||||
|
| [ShuffleNetV2_x2_0](https://bj.bcebos.com/paddlehub/fastdeploy/ShuffleNetV2_x2_0_infer.tgz) | 29MB | 224x224 | 73.15% | 91.2% |
|
||||||
|
| [SqueezeNet1_1](https://bj.bcebos.com/paddlehub/fastdeploy/SqueezeNet1_1_infer.tgz) | 4.8MB | 224x224 | 60.1% | 81.9% |
|
||||||
|
| [InceptionV3](https://bj.bcebos.com/paddlehub/fastdeploy/InceptionV3_infer.tgz) | 92MB | 299x299 | 79.14% | 94.59% |
|
||||||
|
| [PPHGNet_tiny_ssld](https://bj.bcebos.com/paddlehub/fastdeploy/PPHGNet_tiny_ssld_infer.tgz) | 57MB | 224x224 | 81.95% | 96.12% |
|
||||||
|
| [PPHGNet_base_ssld](https://bj.bcebos.com/paddlehub/fastdeploy/PPHGNet_base_ssld_infer.tgz) | 274MB | 224x224 | 85.0% | 97.35% |
|
||||||
|
| [ResNet50_vd](https://bj.bcebos.com/paddlehub/fastdeploy/ResNet50_vd_infer.tgz) | 98MB | 224x224 | 79.12% | 94.44% |
|
||||||
|
|
||||||
## 详细部署文档
|
## 详细部署文档
|
||||||
|
|
||||||
|
@@ -5,67 +5,71 @@
|
|||||||
- 1. 软硬件环境满足要求,参考[FastDeploy环境要求](../../../../../docs/quick_start/requirements.md)
|
- 1. 软硬件环境满足要求,参考[FastDeploy环境要求](../../../../../docs/quick_start/requirements.md)
|
||||||
- 2. FastDeploy Python whl包安装,参考[FastDeploy Python安装](../../../../../docs/quick_start/install.md)
|
- 2. FastDeploy Python whl包安装,参考[FastDeploy Python安装](../../../../../docs/quick_start/install.md)
|
||||||
|
|
||||||
本目录下提供`infer.py`快速完成YOLOv7在CPU/GPU,以及GPU上通过TensorRT加速部署的示例。执行如下脚本即可完成
|
本目录下提供`infer.py`快速完成ResNet50_vd在CPU/GPU,以及GPU上通过TensorRT加速部署的示例。执行如下脚本即可完成
|
||||||
|
|
||||||
```
|
```
|
||||||
# 下载yolov7模型文件和测试图片
|
# 下载ResNet50_vd模型文件和测试图片
|
||||||
wget https://bj.bcebos.com/paddlehub/fastdeploy/yolov7.onnx
|
wget https://bj.bcebos.com/paddlehub/fastdeploy/ResNet50_vd_infer.tgz
|
||||||
wget https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg
|
tar -xvf ResNet50_vd_infer.tgz
|
||||||
|
wget https://gitee.com/paddlepaddle/PaddleClas/raw/release/2.4/deploy/images/ImageNet/ILSVRC2012_val_00000010.jpeg
|
||||||
|
|
||||||
|
|
||||||
#下载部署示例代码
|
#下载部署示例代码
|
||||||
git clone https://github.com/PaddlePaddle/FastDeploy.git
|
git clone https://github.com/PaddlePaddle/FastDeploy.git
|
||||||
cd examples/vison/detection/yolov7/python/
|
cd examples/vision/classification/paddleclas/python
|
||||||
|
|
||||||
# CPU推理
|
# CPU推理
|
||||||
python infer.py --model yolov7.onnx --image 000000087038.jpg --device cpu
|
python infer.py --model ResNet50_vd_infer --image ILSVRC2012_val_00000010.jpeg --device cpu
|
||||||
# GPU推理
|
# GPU推理
|
||||||
python infer.py --model yolov7.onnx --image 000000087038.jpg --device gpu
|
python infer.py --model ResNet50_vd_infer --image ILSVRC2012_val_00000010.jpeg --device gpu
|
||||||
# GPU上使用TensorRT推理 (注意:TensorRT推理第一次运行,有序列化模型的操作,有一定耗时,需要耐心等待)
|
# GPU上使用TensorRT推理 (注意:TensorRT推理第一次运行,有序列化模型的操作,有一定耗时,需要耐心等待)
|
||||||
python infer.py --model yolov7.onnx --image 000000087038.jpg --device gpu --use_trt True
|
python infer.py --model ResNet50_vd_infer --image ILSVRC2012_val_00000010.jpeg --device gpu --use_trt True
|
||||||
```
|
```
|
||||||
|
|
||||||
运行完成可视化结果如下图所示
|
运行完成后返回结果如下所示
|
||||||
|
|
||||||
## YOLOv7 Python接口
|
|
||||||
|
|
||||||
```
|
```
|
||||||
fastdeploy.vision.detection.YOLOv7(model_file, params_file=None, runtime_option=None, model_format=Frontend.ONNX)
|
ClassifyResult(
|
||||||
|
label_ids: 153,
|
||||||
|
scores: 0.686229,
|
||||||
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
YOLOv7模型加载和初始化,其中model_file为导出的ONNX模型格式
|
## PaddleClasModel Python接口
|
||||||
|
|
||||||
|
```
|
||||||
|
fd.vision.classification.PaddleClasModel(model_file, params_file, config_file, runtime_option=None, model_format=Frontend.PADDLE)
|
||||||
|
```
|
||||||
|
|
||||||
|
PaddleClas模型加载和初始化,其中model_file, params_file为训练模型导出的Paddle inference文件,具体请参考其文档说明[模型导出](https://github.com/PaddlePaddle/PaddleClas/blob/release/2.4/docs/zh_CN/inference_deployment/export_model.md#2-%E5%88%86%E7%B1%BB%E6%A8%A1%E5%9E%8B%E5%AF%BC%E5%87%BA)
|
||||||
|
|
||||||
**参数**
|
**参数**
|
||||||
|
|
||||||
> * **model_file**(str): 模型文件路径
|
> * **model_file**(str): 模型文件路径
|
||||||
> * **params_file**(str): 参数文件路径,当模型格式为ONNX格式时,此参数无需设定
|
> * **params_file**(str): 参数文件路径
|
||||||
|
> * **config_file**(str): 推理部署配置文件
|
||||||
> * **runtime_option**(RuntimeOption): 后端推理配置,默认为None,即采用默认配置
|
> * **runtime_option**(RuntimeOption): 后端推理配置,默认为None,即采用默认配置
|
||||||
> * **model_format**(Frontend): 模型格式,默认为ONNX
|
> * **model_format**(Frontend): 模型格式,默认为Paddle格式
|
||||||
|
|
||||||
### predict函数
|
### predict函数
|
||||||
|
|
||||||
> ```
|
> ```
|
||||||
> YOLOv7.predict(image_data, conf_threshold=0.25, nms_iou_threshold=0.5)
|
> PaddleClasModel.predict(input_image, topk=1)
|
||||||
> ```
|
> ```
|
||||||
>
|
>
|
||||||
> 模型预测结口,输入图像直接输出检测结果。
|
> 模型预测结口,输入图像直接输出检测结果。
|
||||||
>
|
>
|
||||||
> **参数**
|
> **参数**
|
||||||
>
|
>
|
||||||
> > * **image_data**(np.ndarray): 输入数据,注意需为HWC,BGR格式
|
> > * **input_image**(np.ndarray): 输入数据,注意需为HWC,BGR格式
|
||||||
> > * **conf_threshold**(float): 检测框置信度过滤阈值
|
> > * **topk**(int):返回预测概率最高的topk个分类结果
|
||||||
> > * **nms_iou_threshold**(float): NMS处理过程中iou阈值
|
|
||||||
|
|
||||||
> **返回**
|
> **返回**
|
||||||
>
|
>
|
||||||
> > 返回`fastdeploy.vision.DetectionResult`结构体,结构体说明参考文档[视觉模型预测结果](../../../../../docs/api/vision_results/)
|
> > 返回`fastdeploy.vision.ClassifyResult`结构体,结构体说明参考文档[视觉模型预测结果](../../../../../docs/api/vision_results/)
|
||||||
|
|
||||||
### 类成员属性
|
|
||||||
|
|
||||||
> > * **size**(list | tuple): 通过此参数修改预处理过程中resize的大小,包含两个整型元素,表示[width, height], 默认值为[640, 640]
|
|
||||||
|
|
||||||
## 其它文档
|
## 其它文档
|
||||||
|
|
||||||
- [YOLOv7 模型介绍](..)
|
- [PaddleClas 模型介绍](..)
|
||||||
- [YOLOv7 C++部署](../cpp)
|
- [PaddleClas C++部署](../cpp)
|
||||||
- [模型预测结果说明](../../../../../docs/api/vision_results/)
|
- [模型预测结果说明](../../../../../docs/api/vision_results/)
|
||||||
|
@@ -23,14 +23,14 @@ class PaddleClasModel(FastDeployModel):
|
|||||||
model_file,
|
model_file,
|
||||||
params_file,
|
params_file,
|
||||||
config_file,
|
config_file,
|
||||||
backend_option=None,
|
runtime_option=None,
|
||||||
model_format=Frontend.PADDLE):
|
model_format=Frontend.PADDLE):
|
||||||
super(PaddleClasModel, self).__init__(backend_option)
|
super(PaddleClasModel, self).__init__(runtime_option)
|
||||||
|
|
||||||
assert model_format == Frontend.PADDLE, "PaddleClasModel only support model format of Frontend.Paddle now."
|
assert model_format == Frontend.PADDLE, "PaddleClasModel only support model format of Frontend.Paddle now."
|
||||||
self._model = C.vision.classification.PaddleClasModel(model_file, params_file,
|
self._model = C.vision.classification.PaddleClasModel(
|
||||||
config_file, self._runtime_option,
|
model_file, params_file, config_file, self._runtime_option,
|
||||||
model_format)
|
model_format)
|
||||||
assert self.initialized, "PaddleClas model initialize failed."
|
assert self.initialized, "PaddleClas model initialize failed."
|
||||||
|
|
||||||
def predict(self, input_image, topk=1):
|
def predict(self, input_image, topk=1):
|
||||||
|
@@ -14,6 +14,8 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
import time
|
||||||
|
import collections
|
||||||
|
|
||||||
|
|
||||||
def topk_accuracy(topk_list, label_list):
|
def topk_accuracy(topk_list, label_list):
|
||||||
@@ -25,6 +27,7 @@ def topk_accuracy(topk_list, label_list):
|
|||||||
def eval_classify(model, image_file_path, label_file_path, topk=5):
|
def eval_classify(model, image_file_path, label_file_path, topk=5):
|
||||||
from tqdm import trange
|
from tqdm import trange
|
||||||
import cv2
|
import cv2
|
||||||
|
import math
|
||||||
|
|
||||||
result_list = []
|
result_list = []
|
||||||
label_list = []
|
label_list = []
|
||||||
@@ -36,6 +39,7 @@ def eval_classify(model, image_file_path, label_file_path, topk=5):
|
|||||||
label_file_path), "The label_file_path:{} is not a file.".format(
|
label_file_path), "The label_file_path:{} is not a file.".format(
|
||||||
label_file_path)
|
label_file_path)
|
||||||
assert isinstance(topk, int), "The tok:{} is not int type".format(topk)
|
assert isinstance(topk, int), "The tok:{} is not int type".format(topk)
|
||||||
|
|
||||||
with open(label_file_path, 'r') as file:
|
with open(label_file_path, 'r') as file:
|
||||||
lines = file.readlines()
|
lines = file.readlines()
|
||||||
for line in lines:
|
for line in lines:
|
||||||
@@ -44,14 +48,30 @@ def eval_classify(model, image_file_path, label_file_path, topk=5):
|
|||||||
label = items[1]
|
label = items[1]
|
||||||
image_label_dict[image_name] = int(label)
|
image_label_dict[image_name] = int(label)
|
||||||
images_num = len(image_label_dict)
|
images_num = len(image_label_dict)
|
||||||
|
twenty_percent_images_num = math.ceil(images_num * 0.2)
|
||||||
|
start_time = 0
|
||||||
|
end_time = 0
|
||||||
|
average_inference_time = 0
|
||||||
|
scores = collections.OrderedDict()
|
||||||
for (image, label), i in zip(image_label_dict.items(),
|
for (image, label), i in zip(image_label_dict.items(),
|
||||||
trange(
|
trange(
|
||||||
images_num, desc='Inference Progress')):
|
images_num, desc='Inference Progress')):
|
||||||
|
if i == twenty_percent_images_num:
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
label_list.append([label])
|
label_list.append([label])
|
||||||
image_path = os.path.join(image_file_path, image)
|
image_path = os.path.join(image_file_path, image)
|
||||||
im = cv2.imread(image_path)
|
im = cv2.imread(image_path)
|
||||||
result = model.predict(im, topk)
|
result = model.predict(im, topk)
|
||||||
result_list.append(result.label_ids)
|
result_list.append(result.label_ids)
|
||||||
|
if i == images_num - 1:
|
||||||
|
end_time = time.time()
|
||||||
|
average_inference_time = round(
|
||||||
|
(end_time - start_time) / (images_num - twenty_percent_images_num), 4)
|
||||||
topk_acc_score = topk_accuracy(np.array(result_list), np.array(label_list))
|
topk_acc_score = topk_accuracy(np.array(result_list), np.array(label_list))
|
||||||
return topk_acc_score
|
if topk == 1:
|
||||||
|
scores.update({'topk1': topk_acc_score})
|
||||||
|
elif topk == 5:
|
||||||
|
scores.update({'topk5': topk_acc_score})
|
||||||
|
scores.update({'average_inference_time': average_inference_time})
|
||||||
|
return scores
|
||||||
|
@@ -15,6 +15,7 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import copy
|
import copy
|
||||||
import collections
|
import collections
|
||||||
|
import math
|
||||||
|
|
||||||
|
|
||||||
def eval_detection(model,
|
def eval_detection(model,
|
||||||
@@ -26,7 +27,7 @@ def eval_detection(model,
|
|||||||
from .utils import CocoDetection
|
from .utils import CocoDetection
|
||||||
from .utils import COCOMetric
|
from .utils import COCOMetric
|
||||||
import cv2
|
import cv2
|
||||||
from tqdm import trange
|
from tqdm import trange
|
||||||
|
|
||||||
if conf_threshold is not None or nms_iou_threshold is not None:
|
if conf_threshold is not None or nms_iou_threshold is not None:
|
||||||
assert conf_threshold is not None and nms_iou_threshold is not None, "The conf_threshold and nms_iou_threshold should be setted at the same time"
|
assert conf_threshold is not None and nms_iou_threshold is not None, "The conf_threshold and nms_iou_threshold should be setted at the same time"
|
||||||
@@ -48,9 +49,15 @@ def eval_detection(model,
|
|||||||
eval_metric = COCOMetric(
|
eval_metric = COCOMetric(
|
||||||
coco_gt=copy.deepcopy(eval_dataset.coco_gt), classwise=False)
|
coco_gt=copy.deepcopy(eval_dataset.coco_gt), classwise=False)
|
||||||
scores = collections.OrderedDict()
|
scores = collections.OrderedDict()
|
||||||
|
twenty_percent_image_num = math.ceil(image_num * 0.2)
|
||||||
|
start_time = 0
|
||||||
|
end_time = 0
|
||||||
|
average_inference_time = 0
|
||||||
for image_info, i in zip(all_image_info,
|
for image_info, i in zip(all_image_info,
|
||||||
trange(
|
trange(
|
||||||
image_num, desc="Inference Progress")):
|
image_num, desc="Inference Progress")):
|
||||||
|
if i == twenty_percent_image_num:
|
||||||
|
start_time = time.time()
|
||||||
im = cv2.imread(image_info["image"])
|
im = cv2.imread(image_info["image"])
|
||||||
im_id = image_info["im_id"]
|
im_id = image_info["im_id"]
|
||||||
if conf_threshold is None and nms_iou_threshold is None:
|
if conf_threshold is None and nms_iou_threshold is None:
|
||||||
@@ -66,8 +73,13 @@ def eval_detection(model,
|
|||||||
'im_id': im_id
|
'im_id': im_id
|
||||||
}
|
}
|
||||||
eval_metric.update(im_id, pred)
|
eval_metric.update(im_id, pred)
|
||||||
|
if i == image_num - 1:
|
||||||
|
end_time = time.time()
|
||||||
|
average_inference_time = round(
|
||||||
|
(end_time - start_time) / (image_num - twenty_percent_image_num), 4)
|
||||||
eval_metric.accumulate()
|
eval_metric.accumulate()
|
||||||
eval_details = eval_metric.details
|
eval_details = eval_metric.details
|
||||||
scores.update(eval_metric.get())
|
scores.update(eval_metric.get())
|
||||||
|
scores.update({'average_inference_time': average_inference_time})
|
||||||
eval_metric.reset()
|
eval_metric.reset()
|
||||||
return scores
|
return scores
|
||||||
|
Reference in New Issue
Block a user