From bf5affb510d7fdfd42f4fc02fd9ba94d6c065595 Mon Sep 17 00:00:00 2001 From: jiangjiajun Date: Wed, 10 Aug 2022 07:19:47 +0000 Subject: [PATCH] add paddleclas draft doct --- csrc/fastdeploy/vision.h | 2 +- .../{ => classification}/ppcls/model.cc | 18 +-- .../vision/{ => classification}/ppcls/model.h | 26 ++++- .../ppcls/ppcls_pybind.cc | 5 +- .../classification/paddleclas/README.md | 28 +++++ .../paddleclas/cpp/CMakeLists.txt | 14 +++ .../classification/paddleclas/cpp/README.md | 77 +++++++++++++ .../classification/paddleclas/cpp/infer.cc | 105 ++++++++++++++++++ .../paddleclas/python/README.md | 71 ++++++++++++ .../classification/paddleclas/python/infer.py | 47 ++++++++ fastdeploy/vision/classification/__init__.py | 30 +++++ .../{ => classification}/ppcls/__init__.py | 12 +- 12 files changed, 410 insertions(+), 25 deletions(-) rename csrc/fastdeploy/vision/{ => classification}/ppcls/model.cc (90%) rename csrc/fastdeploy/vision/{ => classification}/ppcls/model.h (66%) rename csrc/fastdeploy/vision/{ => classification}/ppcls/ppcls_pybind.cc (81%) create mode 100644 examples/vision/classification/paddleclas/README.md create mode 100644 examples/vision/classification/paddleclas/cpp/CMakeLists.txt create mode 100644 examples/vision/classification/paddleclas/cpp/README.md create mode 100644 examples/vision/classification/paddleclas/cpp/infer.cc create mode 100644 examples/vision/classification/paddleclas/python/README.md create mode 100644 examples/vision/classification/paddleclas/python/infer.py create mode 100644 fastdeploy/vision/classification/__init__.py rename fastdeploy/vision/{ => classification}/ppcls/__init__.py (75%) diff --git a/csrc/fastdeploy/vision.h b/csrc/fastdeploy/vision.h index 21371b5a1..fd0824e48 100644 --- a/csrc/fastdeploy/vision.h +++ b/csrc/fastdeploy/vision.h @@ -33,7 +33,7 @@ #include "fastdeploy/vision/faceid/contrib/partial_fc.h" #include "fastdeploy/vision/faceid/contrib/vpl.h" #include "fastdeploy/vision/matting/contrib/modnet.h" -#include "fastdeploy/vision/ppcls/model.h" +#include "fastdeploy/vision/classification/ppcls/model.h" #include "fastdeploy/vision/detection/ppdet/model.h" #include "fastdeploy/vision/ppseg/model.h" #endif diff --git a/csrc/fastdeploy/vision/ppcls/model.cc b/csrc/fastdeploy/vision/classification/ppcls/model.cc similarity index 90% rename from csrc/fastdeploy/vision/ppcls/model.cc rename to csrc/fastdeploy/vision/classification/ppcls/model.cc index c4e5b767c..607d1114e 100644 --- a/csrc/fastdeploy/vision/ppcls/model.cc +++ b/csrc/fastdeploy/vision/classification/ppcls/model.cc @@ -12,15 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "fastdeploy/vision/ppcls/model.h" +#include "fastdeploy/vision/classification/ppcls/model.h" #include "fastdeploy/vision/utils/utils.h" #include "yaml-cpp/yaml.h" namespace fastdeploy { namespace vision { -namespace ppcls { +namespace classification { -Model::Model(const std::string& model_file, const std::string& params_file, +PaddleClasModel::PaddleClasModel(const std::string& model_file, const std::string& params_file, const std::string& config_file, const RuntimeOption& custom_option, const Frontend& model_format) { config_file_ = config_file; @@ -33,7 +33,7 @@ Model::Model(const std::string& model_file, const std::string& params_file, initialized = Initialize(); } -bool Model::Initialize() { +bool PaddleClasModel::Initialize() { if (!BuildPreprocessPipelineFromConfig()) { FDERROR << "Failed to build preprocess pipeline from configuration file." << std::endl; @@ -46,7 +46,7 @@ bool Model::Initialize() { return true; } -bool Model::BuildPreprocessPipelineFromConfig() { +bool PaddleClasModel::BuildPreprocessPipelineFromConfig() { processors_.clear(); YAML::Node cfg; try { @@ -91,7 +91,7 @@ bool Model::BuildPreprocessPipelineFromConfig() { return true; } -bool Model::Preprocess(Mat* mat, FDTensor* output) { +bool PaddleClasModel::Preprocess(Mat* mat, FDTensor* output) { for (size_t i = 0; i < processors_.size(); ++i) { if (!(*(processors_[i].get()))(mat)) { FDERROR << "Failed to process image data in " << processors_[i]->Name() @@ -109,7 +109,7 @@ bool Model::Preprocess(Mat* mat, FDTensor* output) { return true; } -bool Model::Postprocess(const FDTensor& infer_result, ClassifyResult* result, +bool PaddleClasModel::Postprocess(const FDTensor& infer_result, ClassifyResult* result, int topk) { int num_classes = infer_result.shape[1]; const float* infer_result_buffer = @@ -124,7 +124,7 @@ bool Model::Postprocess(const FDTensor& infer_result, ClassifyResult* result, return true; } -bool Model::Predict(cv::Mat* im, ClassifyResult* result, int topk) { +bool PaddleClasModel::Predict(cv::Mat* im, ClassifyResult* result, int topk) { Mat mat(*im); std::vector processed_data(1); if (!Preprocess(&mat, &(processed_data[0]))) { @@ -148,6 +148,6 @@ bool Model::Predict(cv::Mat* im, ClassifyResult* result, int topk) { return true; } -} // namespace ppcls +} // namespace classification } // namespace vision } // namespace fastdeploy diff --git a/csrc/fastdeploy/vision/ppcls/model.h b/csrc/fastdeploy/vision/classification/ppcls/model.h similarity index 66% rename from csrc/fastdeploy/vision/ppcls/model.h rename to csrc/fastdeploy/vision/classification/ppcls/model.h index 71800a7d7..b412bcb2e 100644 --- a/csrc/fastdeploy/vision/ppcls/model.h +++ b/csrc/fastdeploy/vision/classification/ppcls/model.h @@ -19,21 +19,21 @@ namespace fastdeploy { namespace vision { -namespace ppcls { +namespace classification { -class FASTDEPLOY_DECL Model : public FastDeployModel { +class FASTDEPLOY_DECL PaddleClasModel : public FastDeployModel { public: - Model(const std::string& model_file, const std::string& params_file, + PaddleClasModel(const std::string& model_file, const std::string& params_file, const std::string& config_file, const RuntimeOption& custom_option = RuntimeOption(), const Frontend& model_format = Frontend::PADDLE); - std::string ModelName() const { return "ppclas-classify"; } + virtual std::string ModelName() const { return "PaddleClas/Model"; } // TODO(jiangjiajun) Batch is on the way virtual bool Predict(cv::Mat* im, ClassifyResult* result, int topk = 1); - private: + protected: bool Initialize(); bool BuildPreprocessPipelineFromConfig(); @@ -46,6 +46,20 @@ class FASTDEPLOY_DECL Model : public FastDeployModel { std::vector> processors_; std::string config_file_; }; -} // namespace ppcls + +typedef PaddleClasModel PPLCNet; +typedef PaddleClasModel PPLCNetv2; +typedef PaddleClasModel EfficientNet; +typedef PaddleClasModel GhostNet; +typedef PaddleClasModel MobileNetv1; +typedef PaddleClasModel MobileNetv2; +typedef PaddleClasModel MobileNetv3; +typedef PaddleClasModel ShuffleNetv2; +typedef PaddleClasModel SqueezeNet; +typedef PaddleClasModel Inceptionv3; +typedef PaddleClasModel PPHGNet; +typedef PaddleClasModel ResNet50vd; +typedef PaddleClasModel SwinTransformer; +} // namespace classification } // namespace vision } // namespace fastdeploy diff --git a/csrc/fastdeploy/vision/ppcls/ppcls_pybind.cc b/csrc/fastdeploy/vision/classification/ppcls/ppcls_pybind.cc similarity index 81% rename from csrc/fastdeploy/vision/ppcls/ppcls_pybind.cc rename to csrc/fastdeploy/vision/classification/ppcls/ppcls_pybind.cc index 10ff5ee10..190818ac3 100644 --- a/csrc/fastdeploy/vision/ppcls/ppcls_pybind.cc +++ b/csrc/fastdeploy/vision/classification/ppcls/ppcls_pybind.cc @@ -15,12 +15,11 @@ namespace fastdeploy { void BindPPCls(pybind11::module& m) { - auto ppcls_module = m.def_submodule("ppcls", "Module to deploy PaddleClas."); - pybind11::class_(ppcls_module, "Model") + pybind11::class_(m, "PaddleClasModel") .def(pybind11::init()) .def("predict", - [](vision::ppcls::Model& self, pybind11::array& data, int topk = 1) { + [](vision::classification::PaddleClasModel& self, pybind11::array& data, int topk = 1) { auto mat = PyArrayToCvMat(data); vision::ClassifyResult res; self.Predict(&mat, &res, topk); diff --git a/examples/vision/classification/paddleclas/README.md b/examples/vision/classification/paddleclas/README.md new file mode 100644 index 000000000..a834a25e9 --- /dev/null +++ b/examples/vision/classification/paddleclas/README.md @@ -0,0 +1,28 @@ +# PaddleClas 模型部署 + +## 模型版本说明 + +- [PaddleClas Release/2.4](https://github.com/PaddlePaddle/PaddleClas) + +## 准备PaddleClas部署模型 + +PaddleClas模型导出,请参考其文档说明[模型导出](https://github.com/PaddlePaddle/PaddleClas/blob/release/2.4/docs/zh_CN/inference_deployment/export_model.md#2-%E5%88%86%E7%B1%BB%E6%A8%A1%E5%9E%8B%E5%AF%BC%E5%87%BA) + +注意:PaddleClas导出的模型仅包含`inference.pdmodel`和`inference.pdiparams`两个文档,但为了满足部署的需求,同时也需准备其提供的[inference_cls.yaml](https://github.com/PaddlePaddle/PaddleClas/blob/release/2.4/deploy/configs/inference_cls.yaml)文件,FastDeploy会从yaml文件中获取模型在推理时需要的预处理信息,开发者可直接下载此文件使用。但需根据自己的需求修改yaml文件中的配置参数。 + + +## 下载预训练模型 + +为了方便开发者的测试,下面提供了PaddleClas导出的部分模型(含inference_cls.yaml文件),开发者可直接下载使用。 + +| 模型 | 大小 |输入Shape | 精度 | +|:---------------------------------------------------------------- |:----- |:----- | :----- | +| [PPLCNet]() | 141MB | 224x224 |51.4% | +| [PPLCNetv2]() | 10MB | 224x224 |51.4% | +| [EfficientNet]() | | 224x224 | | + + +## 详细部署文档 + +- [Python部署](python) +- [C++部署](cpp) diff --git a/examples/vision/classification/paddleclas/cpp/CMakeLists.txt b/examples/vision/classification/paddleclas/cpp/CMakeLists.txt new file mode 100644 index 000000000..fea1a2888 --- /dev/null +++ b/examples/vision/classification/paddleclas/cpp/CMakeLists.txt @@ -0,0 +1,14 @@ +PROJECT(infer_demo C CXX) +CMAKE_MINIMUM_REQUIRED (VERSION 3.12) + +# 指定下载解压后的fastdeploy库路径 +option(FASTDEPLOY_INSTALL_DIR "Path of downloaded fastdeploy sdk.") + +include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake) + +# 添加FastDeploy依赖头文件 +include_directories(${FASTDEPLOY_INCS}) + +add_executable(infer_demo ${PROJECT_SOURCE_DIR}/infer.cc) +# 添加FastDeploy库依赖 +target_link_libraries(infer_demo ${FASTDEPLOY_LIBS}) diff --git a/examples/vision/classification/paddleclas/cpp/README.md b/examples/vision/classification/paddleclas/cpp/README.md new file mode 100644 index 000000000..2dab72beb --- /dev/null +++ b/examples/vision/classification/paddleclas/cpp/README.md @@ -0,0 +1,77 @@ +# YOLOv7 C++部署示例 + +本目录下提供`infer.cc`快速完成YOLOv7在CPU/GPU,以及GPU上通过TensorRT加速部署的示例。 + +在部署前,需确认以下两个步骤 + +- 1. 软硬件环境满足要求,参考[FastDeploy环境要求](../../../../../docs/quick_start/requirements.md) +- 2. 根据开发环境,下载预编译部署库和samples代码,参考[FastDeploy预编译库](../../../../../docs/compile/prebuilt_libraries.md) + +以Linux上CPU推理为例,在本目录执行如下命令即可完成编译测试 + +``` +mkdir build +cd build +wget https://xxx.tgz +tar xvf fastdeploy-linux-x64-0.2.0.tgz +cmake .. -DFASTDEPLOY_INSTALL_DIR=${PWD}/fastdeploy-linux-x64-0.2.0 +make -j + +#下载官方转换好的yolov7模型文件和测试图片 +wget https://bj.bcebos.com/paddlehub/fastdeploy/yolov7.onnx +wget https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000087038.jpg + + +# CPU推理 +./infer_demo yolov7.onnx 000000087038.jpg 0 +# GPU推理 +./infer_demo yolov7.onnx 000000087038.jpg 1 +# GPU上TensorRT推理 +./infer_demo yolov7.onnx 000000087038.jpg 2 +``` + +## YOLOv7 C++接口 + +### YOLOv7类 + +``` +fastdeploy::vision::detection::YOLOv7( + const string& model_file, + const string& params_file = "", + const RuntimeOption& runtime_option = RuntimeOption(), + const Frontend& model_format = Frontend::ONNX) +``` + +YOLOv7模型加载和初始化,其中model_file为导出的ONNX模型格式。 + +**参数** + +> * **model_file**(str): 模型文件路径 +> * **params_file**(str): 参数文件路径,当模型格式为ONNX时,此参数传入空字符串即可 +> * **runtime_option**(RuntimeOption): 后端推理配置,默认为None,即采用默认配置 +> * **model_format**(Frontend): 模型格式,默认为ONNX格式 + +#### Predict函数 + +> ``` +> YOLOv7::Predict(cv::Mat* im, DetectionResult* result, +> float conf_threshold = 0.25, +> float nms_iou_threshold = 0.5) +> ``` +> +> 模型预测接口,输入图像直接输出检测结果。 +> +> **参数** +> +> > * **im**: 输入图像,注意需为HWC,BGR格式 +> > * **result**: 检测结果,包括检测框,各个框的置信度, DetectionResult说明参考[视觉模型预测结果](../../../../../docs/api/vision_results/) +> > * **conf_threshold**: 检测框置信度过滤阈值 +> > * **nms_iou_threshold**: NMS处理过程中iou阈值 + +### 类成员变量 + +> > * **size**(vector): 通过此参数修改预处理过程中resize的大小,包含两个整型元素,表示[width, height], 默认值为[640, 640] + +- [模型介绍](../../) +- [Python部署](../python) +- [视觉模型预测结果](../../../../../docs/api/vision_results/) diff --git a/examples/vision/classification/paddleclas/cpp/infer.cc b/examples/vision/classification/paddleclas/cpp/infer.cc new file mode 100644 index 000000000..1ddca8f1c --- /dev/null +++ b/examples/vision/classification/paddleclas/cpp/infer.cc @@ -0,0 +1,105 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fastdeploy/vision.h" + +void CpuInfer(const std::string& model_file, const std::string& image_file) { + auto model = fastdeploy::vision::detection::YOLOv7(model_file); + if (!model.Initialized()) { + std::cerr << "Failed to initialize." << std::endl; + return; + } + + auto im = cv::imread(image_file); + auto im_bak = im.clone(); + + fastdeploy::vision::DetectionResult res; + if (!model.Predict(&im, &res)) { + std::cerr << "Failed to predict." << std::endl; + return; + } + + auto vis_im = fastdeploy::vision::Visualize::VisDetection(im_bak, res); + cv::imwrite("vis_result.jpg", vis_im); + std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl; +} + +void GpuInfer(const std::string& model_file, const std::string& image_file) { + auto option = fastdeploy::RuntimeOption(); + option.UseGpu(); + auto model = fastdeploy::vision::detection::YOLOv7(model_file, "", option); + if (!model.Initialized()) { + std::cerr << "Failed to initialize." << std::endl; + return; + } + + auto im = cv::imread(image_file); + auto im_bak = im.clone(); + + fastdeploy::vision::DetectionResult res; + if (!model.Predict(&im, &res)) { + std::cerr << "Failed to predict." << std::endl; + return; + } + + auto vis_im = fastdeploy::vision::Visualize::VisDetection(im_bak, res); + cv::imwrite("vis_result.jpg", vis_im); + std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl; +} + +void TrtInfer(const std::string& model_file, const std::string& image_file) { + auto option = fastdeploy::RuntimeOption(); + option.UseGpu(); + option.UseTrtBackend(); + option.SetTrtInputShape("images", {1, 3, 640, 640}); + auto model = fastdeploy::vision::detection::YOLOv7(model_file, "", option); + if (!model.Initialized()) { + std::cerr << "Failed to initialize." << std::endl; + return; + } + + auto im = cv::imread(image_file); + auto im_bak = im.clone(); + + fastdeploy::vision::DetectionResult res; + if (!model.Predict(&im, &res)) { + std::cerr << "Failed to predict." << std::endl; + return; + } + + auto vis_im = fastdeploy::vision::Visualize::VisDetection(im_bak, res); + cv::imwrite("vis_result.jpg", vis_im); + std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl; +} + +int main(int argc, char* argv[]) { + if (argc < 4) { + std::cout << "Usage: infer_demo path/to/model path/to/image run_option, " + "e.g ./infer_model ./yolov7.onnx ./test.jpeg 0" + << std::endl; + std::cout << "The data type of run_option is int, 0: run with cpu; 1: run " + "with gpu; 2: run with gpu and use tensorrt backend." + << std::endl; + return -1; + } + + if (std::atoi(argv[3]) == 0) { + CpuInfer(argv[1], argv[2]); + } else if (std::atoi(argv[3]) == 1) { + GpuInfer(argv[1], argv[2]); + } else if (std::atoi(argv[3]) == 2) { + TrtInfer(argv[1], argv[2]); + } + return 0; +} diff --git a/examples/vision/classification/paddleclas/python/README.md b/examples/vision/classification/paddleclas/python/README.md new file mode 100644 index 000000000..972ff15eb --- /dev/null +++ b/examples/vision/classification/paddleclas/python/README.md @@ -0,0 +1,71 @@ +# PaddleClas模型 Python部署示例 + +在部署前,需确认以下两个步骤 + +- 1. 软硬件环境满足要求,参考[FastDeploy环境要求](../../../../../docs/quick_start/requirements.md) +- 2. FastDeploy Python whl包安装,参考[FastDeploy Python安装](../../../../../docs/quick_start/install.md) + +本目录下提供`infer.py`快速完成YOLOv7在CPU/GPU,以及GPU上通过TensorRT加速部署的示例。执行如下脚本即可完成 + +``` +# 下载yolov7模型文件和测试图片 +wget https://bj.bcebos.com/paddlehub/fastdeploy/yolov7.onnx +wget https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg + + +#下载部署示例代码 +git clone https://github.com/PaddlePaddle/FastDeploy.git +cd examples/vison/detection/yolov7/python/ + +# CPU推理 +python infer.py --model yolov7.onnx --image 000000087038.jpg --device cpu +# GPU推理 +python infer.py --model yolov7.onnx --image 000000087038.jpg --device gpu +# GPU上使用TensorRT推理 (注意:TensorRT推理第一次运行,有序列化模型的操作,有一定耗时,需要耐心等待) +python infer.py --model yolov7.onnx --image 000000087038.jpg --device gpu --use_trt True +``` + +运行完成可视化结果如下图所示 + +## YOLOv7 Python接口 + +``` +fastdeploy.vision.detection.YOLOv7(model_file, params_file=None, runtime_option=None, model_format=Frontend.ONNX) +``` + +YOLOv7模型加载和初始化,其中model_file为导出的ONNX模型格式 + +**参数** + +> * **model_file**(str): 模型文件路径 +> * **params_file**(str): 参数文件路径,当模型格式为ONNX格式时,此参数无需设定 +> * **runtime_option**(RuntimeOption): 后端推理配置,默认为None,即采用默认配置 +> * **model_format**(Frontend): 模型格式,默认为ONNX + +### predict函数 + +> ``` +> YOLOv7.predict(image_data, conf_threshold=0.25, nms_iou_threshold=0.5) +> ``` +> +> 模型预测结口,输入图像直接输出检测结果。 +> +> **参数** +> +> > * **image_data**(np.ndarray): 输入数据,注意需为HWC,BGR格式 +> > * **conf_threshold**(float): 检测框置信度过滤阈值 +> > * **nms_iou_threshold**(float): NMS处理过程中iou阈值 + +> **返回** +> +> > 返回`fastdeploy.vision.DetectionResult`结构体,结构体说明参考文档[视觉模型预测结果](../../../../../docs/api/vision_results/) + +### 类成员属性 + +> > * **size**(list | tuple): 通过此参数修改预处理过程中resize的大小,包含两个整型元素,表示[width, height], 默认值为[640, 640] + +## 其它文档 + +- [YOLOv7 模型介绍](..) +- [YOLOv7 C++部署](../cpp) +- [模型预测结果说明](../../../../../docs/api/vision_results/) diff --git a/examples/vision/classification/paddleclas/python/infer.py b/examples/vision/classification/paddleclas/python/infer.py new file mode 100644 index 000000000..b3a02be2e --- /dev/null +++ b/examples/vision/classification/paddleclas/python/infer.py @@ -0,0 +1,47 @@ +import fastdeploy as fd +import cv2 + + +def parse_arguments(): + import argparse + import ast + parser = argparse.ArgumentParser() + parser.add_argument( + "--model", required=True, help="Path of PaddleClas model.") + parser.add_argument( + "--image", required=True, help="Path of test image file.") + parser.add_argument( + "--device", + type=str, + default='cpu', + help="Type of inference device, support 'cpu' or 'gpu'.") + parser.add_argument( + "--use_trt", + type=ast.literal_eval, + default=False, + help="Wether to use tensorrt.") + return parser.parse_args() + + +def build_option(args): + option = fd.RuntimeOption() + + if args.device.lower() == "gpu": + option.use_gpu() + + if args.use_trt: + option.use_trt_backend() + option.set_trt_input_shape("images", [1, 3, 640, 640]) + return option + + +args = parse_arguments() + +# 配置runtime,加载模型 +runtime_option = build_option(args) +model = fd.vision.classification.PaddleClasModel(args.model, runtime_option=runtime_option) + +# 预测图片分类结果 +im = cv2.imread(args.image) +result = model.predict(im) +print(result) diff --git a/fastdeploy/vision/classification/__init__.py b/fastdeploy/vision/classification/__init__.py new file mode 100644 index 000000000..3ebcf9c2f --- /dev/null +++ b/fastdeploy/vision/classification/__init__.py @@ -0,0 +1,30 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import absolute_import + +from .ppcls import PaddleClasModel + +PPLCNet = PaddleClasModel +PPLCNetv2 = PaddleClasModel +EfficientNet = PaddleClasModel +GhostNet = PaddleClasModel +MobileNetv1 = PaddleClasModel +MobileNetv2 = PaddleClasModel +MobileNetv3 = PaddleClasModel +ShuffleNetv2 = PaddleClasModel +SqueezeNet = PaddleClasModel +Inceptionv3 = PaddleClasModel +PPHGNet = PaddleClasModel +ResNet50vd = PaddleClasModel +SwinTransformer = PaddleClasModel diff --git a/fastdeploy/vision/ppcls/__init__.py b/fastdeploy/vision/classification/ppcls/__init__.py similarity index 75% rename from fastdeploy/vision/ppcls/__init__.py rename to fastdeploy/vision/classification/ppcls/__init__.py index c43a31084..3207e1e5c 100644 --- a/fastdeploy/vision/ppcls/__init__.py +++ b/fastdeploy/vision/classification/ppcls/__init__.py @@ -14,21 +14,21 @@ from __future__ import absolute_import import logging -from ... import FastDeployModel, Frontend -from ... import c_lib_wrap as C +from .... import FastDeployModel, Frontend +from .... import c_lib_wrap as C -class Model(FastDeployModel): +class PaddleClasModel(FastDeployModel): def __init__(self, model_file, params_file, config_file, backend_option=None, model_format=Frontend.PADDLE): - super(Model, self).__init__(backend_option) + super(PaddleClasModel, self).__init__(backend_option) - assert model_format == Frontend.PADDLE, "PaddleClas only support model format of Frontend.Paddle now." - self._model = C.vision.ppcls.Model(model_file, params_file, + assert model_format == Frontend.PADDLE, "PaddleClasModel only support model format of Frontend.Paddle now." + self._model = C.vision.classification.PaddleClasModel(model_file, params_file, config_file, self._runtime_option, model_format) assert self.initialized, "PaddleClas model initialize failed."