# How to Integrate a New Model in FastDeploy

Adding a new model to FastDeploy means adding C++ and Python deployment support for it. This document takes the ResNet50 model from torchvision v0.12.0 as an example and walks through integrating an external model into FastDeploy, which consists of the following three steps.

| Step | Description | Files to create or modify |
|:----:|-------------|----------------------------|
| 1 | Add the model implementation to the corresponding task module under `fastdeploy/vision` | `resnet.h`, `resnet.cc`, `vision.h` |
| 2 | Bind the Python interface via pybind | `resnet_pybind.cc`, `classification_pybind.cc` |
| 3 | Implement the Python-side calling interface | `resnet.py`, `__init__.py` |

Once these three steps are done, the external model is integrated.
If you would like to contribute your code to FastDeploy, you also need to add test code, documentation, and code comments for the new model; see the Testing section below.

## Model Integration

### Prepare the Model

Before integrating an external model, first convert the trained model (e.g. `.pt`, `.pdparams`) into a model format that FastDeploy supports for deployment (e.g. `.onnx`, `.pdmodel`). Most open-source repositories provide conversion scripts that can be used directly. Since torchvision does not provide one, we write the conversion script ourselves; the following code converts `torchvision.models.resnet50` to `resnet50.onnx`:

```python
import torch
import torchvision.models as models

model = models.resnet50(pretrained=True)
batch_size = 1                       # batch size
input_shape = (3, 224, 224)          # input shape; change it to match your own model
model.eval()
x = torch.randn(batch_size, *input_shape)   # generate a dummy input tensor
export_onnx_file = "resnet50.onnx"           # name of the exported ONNX file
torch.onnx.export(model,
                  x,
                  export_onnx_file,
                  opset_version=12,
                  input_names=["input"],     # input name
                  output_names=["output"],   # output name
                  dynamic_axes={"input": {0: "batch_size"},    # dynamic batch dimension
                                "output": {0: "batch_size"}})
```

Running the script above produces the `resnet50.onnx` file.
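
Before starting the C++ integration, it can help to verify that the exported model is valid. Below is a minimal sanity check, assuming the `onnxruntime` and `numpy` packages are installed; the input name `"input"` matches the export script above.

```python
import numpy as np
import onnxruntime as ort

# Load the exported model and run one dummy batch through it.
session = ort.InferenceSession("resnet50.onnx")
dummy = np.random.randn(1, 3, 224, 224).astype(np.float32)
outputs = session.run(None, {"input": dummy})
print(outputs[0].shape)  # for ResNet50 this should print (1, 1000)
```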

### C++ Part

- Create the `resnet.h` file
  - Create location
    - `FastDeploy/fastdeploy/vision/classification/contrib/resnet.h` (FastDeploy/C++ code directory/vision model/task name/external model/model name.h)
  - Create content
    - First declare the `ResNet` class in `resnet.h`, inheriting from the `FastDeployModel` base class, then declare `Predict`, `Initialize`, `Preprocess`, `Postprocess`, the constructor, and the necessary variables. For the code details, refer to `resnet.h`.

```C++
class FASTDEPLOY_DECL ResNet : public FastDeployModel {
 public:
  ResNet(...);
  virtual bool Predict(...);
 private:
  bool Initialize();
  bool Preprocess(...);
  bool Postprocess(...);
};
```
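
The preprocessing parameters that the pybind code below exposes to Python (`size`, `mean_vals`, `std_vals`) also need to be declared as public members of this class. A minimal sketch of those declarations is shown here, assuming `std::vector<float>` for the normalization parameters; the exact types and default values are up to your implementation.

```C++
 public:
  // Target (width, height) used during resizing; exposed to Python as "size".
  std::vector<int> size;
  // Per-channel mean and std used during normalization; exposed as "mean_vals" / "std_vals".
  std::vector<float> mean_vals;
  std::vector<float> std_vals;
```
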

- Create the `resnet.cc` file
  - Create location
    - `FastDeploy/fastdeploy/vision/classification/contrib/resnet.cc` (FastDeploy/C++ code directory/vision model/task name/external model/model name.cc)
  - Create content
    - Implement in `resnet.cc` the functions declared in `resnet.h`. `Preprocess` and `Postprocess` need to reproduce the pre- and post-processing logic of the official source repository. The responsibilities of each function are outlined below; for the code details, refer to `resnet.cc`.

```C++
ResNet::ResNet(...) {
  // Constructor logic
  // 1. Specify the backends  2. Set the RuntimeOption  3. Call the Initialize() function
}

bool ResNet::Initialize() {
  // Initialization logic
  // 1. Assign values to the global variables  2. Call the InitRuntime() function
  return true;
}

bool ResNet::Preprocess(Mat* mat, FDTensor* output) {
  // Preprocessing logic
  // 1. Resize  2. BGR2RGB  3. Normalize  4. HWC2CHW  5. Store the processed data in an FDTensor
  return true;
}

bool ResNet::Postprocess(FDTensor& infer_result, ClassifyResult* result, int topk) {
  // Postprocessing logic
  // 1. Softmax  2. Choose the topk labels  3. Store the results in a ClassifyResult
  return true;
}

bool ResNet::Predict(cv::Mat* im, ClassifyResult* result, int topk) {
  Preprocess(...)
  Infer(...)
  Postprocess(...)
  return true;
}
```
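
Once `resnet.cc` compiles, the new C++ API can be exercised directly. Below is an illustrative sketch of a typical call sequence, not taken from the FastDeploy examples: `resnet50.onnx` and `test.jpg` are placeholders, and the constructor arguments follow the pybind signature shown in the next section.

```C++
#include <iostream>
#include "fastdeploy/vision.h"

int main() {
  // Model file, params file (empty for ONNX), runtime option, and model format.
  auto model = fastdeploy::vision::classification::ResNet(
      "resnet50.onnx", "", fastdeploy::RuntimeOption(),
      fastdeploy::ModelFormat::ONNX);
  if (!model.Initialized()) {
    std::cerr << "Failed to initialize the model." << std::endl;
    return -1;
  }
  cv::Mat im = cv::imread("test.jpg");
  fastdeploy::vision::ClassifyResult result;
  if (!model.Predict(&im, &result, 5)) {  // top-5 classification
    std::cerr << "Prediction failed." << std::endl;
    return -1;
  }
  std::cout << result.Str() << std::endl;
  return 0;
}
```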

- Add the new model to the `vision.h` file
  - Modify location
    - `FastDeploy/fastdeploy/vision.h`
  - Modify content

```C++
#ifdef ENABLE_VISION
#include "fastdeploy/vision/classification/contrib/resnet.h"
#endif
```

### Pybind Part

- Create the pybind file
  - Create location
    - `FastDeploy/fastdeploy/vision/classification/contrib/resnet_pybind.cc` (FastDeploy/C++ code directory/vision model/task name/external model/model name_pybind.cc)
  - Create content
    - Use pybind to bind the C++ functions and variables into Python; for the code details, refer to `resnet_pybind.cc`.

```C++
void BindResNet(pybind11::module& m) {
  pybind11::class_<vision::classification::ResNet, FastDeployModel>(
      m, "ResNet")
      .def(pybind11::init<std::string, std::string, RuntimeOption, ModelFormat>())
      .def("predict", ...)
      .def_readwrite("size", &vision::classification::ResNet::size)
      .def_readwrite("mean_vals", &vision::classification::ResNet::mean_vals)
      .def_readwrite("std_vals", &vision::classification::ResNet::std_vals);
}
```

- Call the pybind binding function
  - Modify location
    - `FastDeploy/fastdeploy/vision/classification/classification_pybind.cc` (FastDeploy/C++ code directory/vision model/task name/task name_pybind.cc)
  - Modify content

```C++
void BindResNet(pybind11::module& m);

void BindClassification(pybind11::module& m) {
  auto classification_module =
      m.def_submodule("classification", "Image classification models.");
  BindResNet(classification_module);
}
```

### Python Part

- Create the `resnet.py` file
  - Create location
    - `FastDeploy/python/fastdeploy/vision/classification/contrib/resnet.py` (FastDeploy/Python code directory/fastdeploy/vision model/task name/external model/model name.py)
  - Create content
    - Create a `ResNet` class that inherits from `FastDeployModel`, and implement `__init__`, the functions bound via pybind (such as `predict()`), and the functions that set and get the pybind-bound global variables. For the code details, refer to `resnet.py`.

```python
class ResNet(FastDeployModel):
    def __init__(self, ...):
        self._model = C.vision.classification.ResNet(...)

    def predict(self, input_image, topk=1):
        return self._model.predict(input_image, topk)

    @property
    def size(self):
        return self._model.size

    @size.setter
    def size(self, wh):
        ...
```
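
Once the wheel is built and installed (see the Testing section below), the new class is used like any other FastDeploy vision model. A minimal illustrative sketch, assuming the constructor mirrors the C++ one and only the model file is passed; `test.jpg` is a placeholder:

```python
import cv2
import fastdeploy as fd

model = fd.vision.classification.ResNet("resnet50.onnx")
im = cv2.imread("test.jpg")
result = model.predict(im, topk=5)  # calls the pybind-bound C++ Predict
print(result)
```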

- Import the `ResNet` class
  - Modify location
    - `FastDeploy/python/fastdeploy/vision/classification/__init__.py` (FastDeploy/Python code directory/fastdeploy/vision model/task name/`__init__.py`)
  - Modify content

```python
from .contrib.resnet import ResNet
```

## Testing

### Build

- C++
  - Working directory: `FastDeploy/`

```bash
mkdir build && cd build
cmake .. -DENABLE_ORT_BACKEND=ON -DENABLE_VISION=ON -DCMAKE_INSTALL_PREFIX=${PWD}/fastdeploy-0.0.3 \
  -DENABLE_PADDLE_BACKEND=ON -DENABLE_TRT_BACKEND=ON -DWITH_GPU=ON -DTRT_DIRECTORY=/PATH/TO/TensorRT/
make -j8
make install
```

The build produces the `build/fastdeploy-0.0.3/` directory.

- Python
  - Working directory: `FastDeploy/python/`

```bash
export TRT_DIRECTORY=/PATH/TO/TensorRT/    # when using TensorRT, set its location and enable ENABLE_TRT_BACKEND
export ENABLE_TRT_BACKEND=ON
export WITH_GPU=ON
export ENABLE_PADDLE_BACKEND=ON
export ENABLE_OPENVINO_BACKEND=ON
export ENABLE_VISION=ON
export ENABLE_ORT_BACKEND=ON
python setup.py build
python setup.py bdist_wheel
cd dist
pip install fastdeploy_gpu_python-<version>-cpxx-cpxxm-<architecture>.whl
```
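
A quick way to confirm that the newly bound class is exposed by the installed wheel (an illustrative check, not part of the official test flow):

```python
import fastdeploy as fd

print(fd.vision.classification.ResNet)
```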

### Write the Test Code

- Create location: `FastDeploy/examples/vision/classification/resnet/` (FastDeploy/examples directory/vision model/task name/model name/)
- Directory structure

```
.
├── cpp
│   ├── CMakeLists.txt
│   ├── infer.cc    // C++ test code
│   └── README.md   // C++ usage documentation
├── python
│   ├── infer.py    // Python test code
│   └── README.md   // Python usage documentation
└── README.md   // documentation for the ResNet model integration
```

- C++
  - Write the `CMakeLists.txt`, the C++ code, and the `README.md`; refer to `cpp/` for the content.
  - Build `infer.cc`
    - Working directory: `FastDeploy/examples/vision/classification/resnet/cpp/`

```bash
mkdir build && cd build
cmake .. -DFASTDEPLOY_INSTALL_DIR=/PATH/TO/FastDeploy/build/fastdeploy-0.0.3/
make
```

- Python
  - For the Python code and the `README.md`, refer to `python/`.

## Add Comments to the Code

To help users understand the code, add comments to the newly added code. You can follow the examples below.

- C++ code: add comments for the functions and variables in `resnet.h`. There are three comment styles, shown below; for details, refer to `resnet.h`.

```C++
/** \brief Predict for the input "im", the result will be saved in "result".
*
* \param[in] im Input image for inference.
* \param[in] result Saving the inference result.
* \param[in] topk The length of the returned values, e.g., if topk == 2, the result will include the 2 most likely class labels for the input image.
*/
virtual bool Predict(cv::Mat* im, ClassifyResult* result, int topk = 1);

/// Tuple of (width, height)
std::vector<int> size;

/*! @brief Initialize for ResNet model, assign values to the global variables and call InitRuntime()
*/
bool Initialize();
```
- Python code: add appropriate comments for the functions and variables in `resnet.py`, as in the example below; for details, refer to `resnet.py`.

```python
def predict(self, input_image, topk=1):
    """Classify an input image.

    :param input_image: (numpy.ndarray) The input image data, a 3-D array with HWC layout in BGR format
    :param topk: (int) The number of top results returned, ranked by classification confidence score, default 1
    :return: ClassifyResult
    """
    return self._model.predict(input_image, topk)
```

For the other files touched during the model integration, you are also encouraged to add appropriate comments describing the implementation details.