[Model] Support PaddleDetection SSD Model (#630)

* Add files via upload

* Add files via upload

* Add files via upload

* Add files via upload

* Add files via upload

* Add files via upload

* Add files via upload

* Add files via upload

* Add files via upload

* Add files via upload

* Add files via upload

* Add files via upload

* Add files via upload

* Add files via upload

Co-authored-by: DefTruth <31974251+DefTruth@users.noreply.github.com>
This commit is contained in:
totorolin
2022-11-28 19:13:02 +08:00
committed by GitHub
parent 56c39b5f96
commit 941057888a
7 changed files with 202 additions and 4 deletions

View File

@@ -15,6 +15,7 @@
- [YOLOX系列模型](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/configs/yolox)
- [FasterRCNN系列模型](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/configs/faster_rcnn)
- [MaskRCNN系列模型](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/configs/mask_rcnn)
- [SSD系列模型](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.5/configs/ssd)
## 导出部署模型
@@ -35,16 +36,17 @@
|:---------------------------------------------------------------- |:----- |:----- | :------ |
| [picodet_l_320_coco_lcnet](https://bj.bcebos.com/paddlehub/fastdeploy/picodet_l_320_coco_lcnet.tgz) |23MB | Box AP 42.6% |
| [ppyoloe_crn_l_300e_coco](https://bj.bcebos.com/paddlehub/fastdeploy/ppyoloe_crn_l_300e_coco.tgz) |200MB | Box AP 51.4% |
| [ppyoloe_plus_crn_m_80e_coco](https://bj.bcebos.com/fastdeploy/models/ppyoloe_plus_crn_m_80e_coco.tgz) |83.3MB | Box AP 49.8% |
| [ppyolo_r50vd_dcn_1x_coco](https://bj.bcebos.com/paddlehub/fastdeploy/ppyolo_r50vd_dcn_1x_coco.tgz) | 180MB | Box AP 44.8% | 暂不支持TensorRT |
| [ppyolov2_r101vd_dcn_365e_coco](https://bj.bcebos.com/paddlehub/fastdeploy/ppyolov2_r101vd_dcn_365e_coco.tgz) | 282MB | Box AP 49.7% | 暂不支持TensorRT |
| [yolov3_darknet53_270e_coco](https://bj.bcebos.com/paddlehub/fastdeploy/yolov3_darknet53_270e_coco.tgz) |237MB | Box AP 39.1% | |
| [yolox_s_300e_coco](https://bj.bcebos.com/paddlehub/fastdeploy/yolox_s_300e_coco.tgz) | 35MB | Box AP 40.4% | |
| [faster_rcnn_r50_vd_fpn_2x_coco](https://bj.bcebos.com/paddlehub/fastdeploy/faster_rcnn_r50_vd_fpn_2x_coco.tgz) | 160MB | Box AP 40.8%| 暂不支持TensorRT |
| [mask_rcnn_r50_1x_coco](https://bj.bcebos.com/paddlehub/fastdeploy/mask_rcnn_r50_1x_coco.tgz) | 128M | Box AP 37.4%, Mask AP 32.8%| 暂不支持TensorRT、ORT |
| [ssd_mobilenet_v1_300_120e_voc](https://bj.bcebos.com/paddlehub/fastdeploy/ssd_mobilenet_v1_300_120e_voc.tgz) | 21.7M | Box AP 73.8%| 暂不支持TensorRT、ORT |
| [ssd_vgg16_300_240e_voc](https://bj.bcebos.com/paddlehub/fastdeploy/ssd_vgg16_300_240e_voc.tgz) | 97.7M | Box AP 77.8%| 暂不支持TensorRT、ORT |
| [ssdlite_mobilenet_v1_300_coco](https://bj.bcebos.com/paddlehub/fastdeploy/ssdlite_mobilenet_v1_300_coco.tgz) | 24.4M | | 暂不支持TensorRT、ORT |
## 详细部署文档
- [Python部署](python)
- [C++部署](cpp)
- [服务化部署](serving)
- [C++部署](cpp)

View File

@@ -0,0 +1,100 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision.h"
#ifdef WIN32
const char sep = '\\';
#else
const char sep = '/';
#endif
void CpuInfer(const std::string& model_dir, const std::string& image_file) {
auto model_file = model_dir + sep + "model.pdmodel";
auto params_file = model_dir + sep + "model.pdiparams";
auto config_file = model_dir + sep + "infer_cfg.yml";
auto option = fastdeploy::RuntimeOption();
option.UseCpu();
option.UsePaddleBackend();
auto model = fastdeploy::vision::detection::SSD(model_file, params_file,
config_file, option);
if (!model.Initialized()) {
std::cerr << "Failed to initialize." << std::endl;
return;
}
auto im = cv::imread(image_file);
auto im_bak = im.clone();
fastdeploy::vision::DetectionResult res;
if (!model.Predict(&im, &res)) {
std::cerr << "Failed to predict." << std::endl;
return;
}
std::cout << res.Str() << std::endl;
auto vis_im = fastdeploy::vision::VisDetection(im_bak, res, 0.5);
cv::imwrite("vis_result.jpg", vis_im);
std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
}
void GpuInfer(const std::string& model_dir, const std::string& image_file) {
auto model_file = model_dir + sep + "model.pdmodel";
auto params_file = model_dir + sep + "model.pdiparams";
auto config_file = model_dir + sep + "infer_cfg.yml";
auto option = fastdeploy::RuntimeOption();
option.UseGpu();
auto model = fastdeploy::vision::detection::SSD(model_file, params_file,
config_file, option);
if (!model.Initialized()) {
std::cerr << "Failed to initialize." << std::endl;
return;
}
auto im = cv::imread(image_file);
auto im_bak = im.clone();
fastdeploy::vision::DetectionResult res;
if (!model.Predict(&im, &res)) {
std::cerr << "Failed to predict." << std::endl;
return;
}
std::cout << res.Str() << std::endl;
auto vis_im = fastdeploy::vision::VisDetection(im_bak, res, 0.5);
cv::imwrite("vis_result.jpg", vis_im);
std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
}
int main(int argc, char* argv[]) {
if (argc < 4) {
std::cout
<< "Usage: infer_demo path/to/model_dir path/to/image run_option, "
"e.g ./infer_model ./ssd_dirname ./test.jpeg 0"
<< std::endl;
std::cout << "The data type of run_option is int, 0: run with cpu; 1: run "
"with gpu."
<< std::endl;
return -1;
}
if (std::atoi(argv[3]) == 0) {
CpuInfer(argv[1], argv[2]);
} else if (std::atoi(argv[3]) == 1) {
GpuInfer(argv[1], argv[2]);
}
return 0;
}

View File

@@ -40,6 +40,7 @@ fastdeploy.vision.detection.YOLOv3(model_file, params_file, config_file, runtime
fastdeploy.vision.detection.PPYOLO(model_file, params_file, config_file, runtime_option=None, model_format=ModelFormat.PADDLE)
fastdeploy.vision.detection.FasterRCNN(model_file, params_file, config_file, runtime_option=None, model_format=ModelFormat.PADDLE)
fastdeploy.vision.detection.MaskRCNN(model_file, params_file, config_file, runtime_option=None, model_format=ModelFormat.PADDLE)
fastdeploy.vision.detection.SSD(model_file, params_file, config_file, runtime_option=None, model_format=ModelFormat.PADDLE)
```
PaddleDetection模型加载和初始化其中model_file params_file为导出的Paddle部署模型格式, config_file为PaddleDetection同时导出的部署配置yaml文件

View File

@@ -0,0 +1,50 @@
import fastdeploy as fd
import cv2
import os
def parse_arguments():
import argparse
import ast
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_dir",
required=True,
help="Path of PaddleDetection model directory")
parser.add_argument(
"--image", required=True, help="Path of test image file.")
parser.add_argument(
"--device",
type=str,
default='cpu',
help="Type of inference device, support 'cpu' or 'gpu'.")
return parser.parse_args()
def build_option(args):
option = fd.RuntimeOption()
if args.device.lower() == "gpu":
option.use_gpu()
return option
args = parse_arguments()
model_file = os.path.join(args.model_dir, "model.pdmodel")
params_file = os.path.join(args.model_dir, "model.pdiparams")
config_file = os.path.join(args.model_dir, "infer_cfg.yml")
# 配置runtime加载模型
runtime_option = build_option(args)
model = fd.vision.detection.SSD(
model_file, params_file, config_file, runtime_option=runtime_option)
# 预测图片检测结果
im = cv2.imread(args.image)
result = model.predict(im.copy())
print(result)
# 预测结果可视化
vis_im = fd.vision.vis_detection(im, result, score_threshold=0.5)
cv2.imwrite("visualized_result.jpg", vis_im)
print("Visualized result save in ./visualized_result.jpg")

View File

@@ -160,6 +160,22 @@ class FASTDEPLOY_DECL MaskRCNN : public PPDetBase {
virtual std::string ModelName() const { return "PaddleDetection/MaskRCNN"; }
};
class FASTDEPLOY_DECL SSD : public PPDetBase {
public:
SSD(const std::string& model_file, const std::string& params_file,
const std::string& config_file,
const RuntimeOption& custom_option = RuntimeOption(),
const ModelFormat& model_format = ModelFormat::PADDLE)
: PPDetBase(model_file, params_file, config_file, custom_option,
model_format) {
valid_cpu_backends = {Backend::PDINFER, Backend::LITE};
valid_gpu_backends = {Backend::PDINFER};
initialized = Initialize();
}
virtual std::string ModelName() const { return "PaddleDetection/SSD"; }
};
} // namespace detection
} // namespace vision
} // namespace fastdeploy

View File

@@ -108,5 +108,9 @@ void BindPPDet(pybind11::module& m) {
pybind11::class_<vision::detection::MaskRCNN, vision::detection::PPDetBase>(m, "MaskRCNN")
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
ModelFormat>());
pybind11::class_<vision::detection::SSD, vision::detection::PPDetBase>(m, "SSD")
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
ModelFormat>());
}
} // namespace fastdeploy

View File

@@ -272,3 +272,28 @@ class MaskRCNN(PPYOLOE):
raise Exception(
"batch_predict is not supported for MaskRCNN model now.")
class SSD(PPYOLOE):
def __init__(self,
model_file,
params_file,
config_file,
runtime_option=None,
model_format=ModelFormat.PADDLE):
"""Load a SSD model exported by PaddleDetection.
:param model_file: (str)Path of model file, e.g ssd/model.pdmodel
:param params_file: (str)Path of parameters file, e.g ssd/model.pdiparams, if the model_fomat is ModelFormat.ONNX, this param will be ignored, can be set as empty string
:param config_file: (str)Path of configuration file for deployment, e.g ppyoloe/infer_cfg.yml
:param runtime_option: (fastdeploy.RuntimeOption)RuntimeOption for inference this model, if it's None, will use the default backend on CPU
:param model_format: (fastdeploy.ModelForamt)Model format of the loaded model
"""
super(PPYOLOE, self).__init__(runtime_option)
assert model_format == ModelFormat.PADDLE, "SSD model only support model format of ModelFormat.Paddle now."
self._model = C.vision.detection.SSD(
model_file, params_file, config_file, self._runtime_option,
model_format)
assert self.initialized, "SSD model initialize failed."