From 718698a32a5afe64b18eb8a18fbf0d6d954dbcd0 Mon Sep 17 00:00:00 2001
From: WJJ1995
Date: Wed, 26 Oct 2022 14:30:04 +0800
Subject: [PATCH] [Model] add RobustVideoMatting model (#400)

* add yolov5cls
* fixed bugs
* fixed bugs
* fixed preprocess bug
* add yolov5cls readme
* deal with comments
* Add YOLOv5Cls Note
* add yolov5cls test
* add rvm support
* support rvm model
* add rvm demo
* fixed bugs
* add rvm readme
* add TRT support
* add trt support
* add rvm test
* add EXPORT.md
* rename export.md
* rm poros doxyen
* deal with comments
* deal with comments
* add rvm video_mode note

Co-authored-by: Jason
Co-authored-by: DefTruth <31974251+DefTruth@users.noreply.github.com>
---
 examples/vision/matting/README.md             |   1 +
 .../vision/matting/modnet/python/README.md    |   6 +-
 examples/vision/matting/rvm/README.md         |  30 +++
 .../vision/matting/rvm/cpp/CMakeLists.txt     |  14 ++
 examples/vision/matting/rvm/cpp/README.md     |  87 +++++++++
 examples/vision/matting/rvm/cpp/infer.cc      | 131 +++++++++++++
 examples/vision/matting/rvm/export.md         | 116 +++++++++++
 examples/vision/matting/rvm/python/README.md  |  88 +++++++++
 examples/vision/matting/rvm/python/infer.py   | 112 +++++++++++
 fastdeploy/backends/poros/common/iengine.h    |   5 -
 fastdeploy/backends/tensorrt/trt_backend.cc   |   2 +-
 fastdeploy/vision.h                           |   1 +
 fastdeploy/vision/matting/contrib/modnet.cc   |   2 +-
 fastdeploy/vision/matting/contrib/rvm.cc      | 182 ++++++++++++++++++
 fastdeploy/vision/matting/contrib/rvm.h       |  92 +++++++++
 .../vision/matting/contrib/rvm_pybind.cc      |  34 ++++
 fastdeploy/vision/matting/matting_pybind.cc   |   4 +-
 .../vision/matting/ppmatting/ppmatting.cc     |   5 +-
 fastdeploy/vision/segmentation/ppseg/model.cc |   1 -
 python/fastdeploy/vision/matting/__init__.py  |   1 +
 .../fastdeploy/vision/matting/contrib/rvm.py  |  81 ++++++++
 tests/eval_example/test_rvm.py                | 101 ++++++++++
 22 files changed, 1080 insertions(+), 16 deletions(-)
 mode change 100644 => 100755 examples/vision/matting/README.md
 mode change 100644 => 100755 examples/vision/matting/modnet/python/README.md
 create mode 100755 examples/vision/matting/rvm/README.md
 create mode 100644 examples/vision/matting/rvm/cpp/CMakeLists.txt
 create mode 100755 examples/vision/matting/rvm/cpp/README.md
 create mode 100755 examples/vision/matting/rvm/cpp/infer.cc
 create mode 100755 examples/vision/matting/rvm/export.md
 create mode 100755 examples/vision/matting/rvm/python/README.md
 create mode 100755 examples/vision/matting/rvm/python/infer.py
 mode change 100644 => 100755 fastdeploy/backends/tensorrt/trt_backend.cc
 mode change 100644 => 100755 fastdeploy/vision/matting/contrib/modnet.cc
 create mode 100755 fastdeploy/vision/matting/contrib/rvm.cc
 create mode 100755 fastdeploy/vision/matting/contrib/rvm.h
 create mode 100755 fastdeploy/vision/matting/contrib/rvm_pybind.cc
 mode change 100644 => 100755 fastdeploy/vision/matting/ppmatting/ppmatting.cc
 mode change 100644 => 100755 fastdeploy/vision/segmentation/ppseg/model.cc
 create mode 100755 python/fastdeploy/vision/matting/contrib/rvm.py
 create mode 100644 tests/eval_example/test_rvm.py

diff --git a/examples/vision/matting/README.md b/examples/vision/matting/README.md
old mode 100644
new mode 100755
index d4b4f5c32..f7582fcdf
--- a/examples/vision/matting/README.md
+++ b/examples/vision/matting/README.md
@@ -5,6 +5,7 @@ FastDeploy currently supports deploying the following matting models
 | Model | Description | Model Format | Version |
 | :--- | :--- | :------- | :--- |
 | [ZHKKKe/MODNet](./modnet) | MODNet series models | ONNX | [CommitID:28165a4](https://github.com/ZHKKKe/MODNet/commit/28165a4) |
+| [PeterL1n/RobustVideoMatting](./rvm) | RobustVideoMatting series models | ONNX | [CommitID:81a1093](https://github.com/PeterL1n/RobustVideoMatting/commit/81a1093) |
 | [PaddleSeg/PP-Matting](./ppmatting) | PP-Matting series models | Paddle | [Release/2.6](https://github.com/PaddlePaddle/PaddleSeg/tree/release/2.6/Matting) |
 | [PaddleSeg/PP-HumanMatting](./ppmatting) | PP-HumanMatting series models | Paddle | [Release/2.6](https://github.com/PaddlePaddle/PaddleSeg/tree/release/2.6/Matting) |
 | [PaddleSeg/ModNet](./ppmatting) | ModNet series models | Paddle | [Release/2.6](https://github.com/PaddlePaddle/PaddleSeg/tree/release/2.6/Matting) |
diff --git a/examples/vision/matting/modnet/python/README.md b/examples/vision/matting/modnet/python/README.md
old mode 100644
new mode 100755
index decfef8cf..d84d95ac5
--- a/examples/vision/matting/modnet/python/README.md
+++ b/examples/vision/matting/modnet/python/README.md
@@ -52,16 +52,14 @@ MODNet model loading and initialization, where model_file is the exported ONNX model
 ### predict function

 > ```python
-> MODNet.predict(image_data, conf_threshold=0.25, nms_iou_threshold=0.5)
+> MODNet.predict(image_data)
 > ```
 >
-> Model prediction interface. Takes an image and directly returns the detection result.
+> Model prediction interface. Takes an image and directly returns the matting result.
 >
 > **Parameters**
 >
 > > * **image_data**(np.ndarray): input data, must be in HWC, BGR format
-> > * **conf_threshold**(float): confidence threshold for filtering detection boxes
-> > * **nms_iou_threshold**(float): IoU threshold used during NMS
 > **Returns**
 >
diff --git a/examples/vision/matting/rvm/README.md b/examples/vision/matting/rvm/README.md
new file mode 100755
index 000000000..16f33aae4
--- /dev/null
+++ b/examples/vision/matting/rvm/README.md
@@ -0,0 +1,30 @@
# RobustVideoMatting Model Deployment

## Model Version

- [RobustVideoMatting](https://github.com/PeterL1n/RobustVideoMatting/commit/81a1093)

## Supported Models

FastDeploy currently supports deploying the following models

- [RobustVideoMatting models](https://github.com/PeterL1n/RobustVideoMatting)

## Download Pretrained Models

For developers' convenience, the exported RobustVideoMatting models below are provided for direct download.

| Model | Size | Accuracy | Note |
|:---------------------------------------------------------------- |:----- |:----- | :------ |
| [rvm_mobilenetv3_fp32.onnx](https://bj.bcebos.com/paddlehub/fastdeploy/rvm_mobilenetv3_fp32.onnx) | 15MB | - | - |
| [rvm_resnet50_fp32.onnx](https://bj.bcebos.com/paddlehub/fastdeploy/rvm_resnet50_fp32.onnx) | 103MB | - | - |
| [rvm_mobilenetv3_trt.onnx](https://bj.bcebos.com/paddlehub/fastdeploy/rvm_mobilenetv3_trt.onnx) | 15MB | - | for TensorRT |
| [rvm_resnet50_trt.onnx](https://bj.bcebos.com/paddlehub/fastdeploy/rvm_resnet50_trt.onnx) | 103MB | - | for TensorRT |

**Note**:
- To run inference with TensorRT, download the ONNX model files whose names end with trt

## Detailed Deployment Docs

- [Python Deployment](python)
- [C++ Deployment](cpp)
diff --git a/examples/vision/matting/rvm/cpp/CMakeLists.txt b/examples/vision/matting/rvm/cpp/CMakeLists.txt
new file mode 100644
index 000000000..93540a7e8
--- /dev/null
+++ b/examples/vision/matting/rvm/cpp/CMakeLists.txt
@@ -0,0 +1,14 @@
PROJECT(infer_demo C CXX)
CMAKE_MINIMUM_REQUIRED(VERSION 3.10)

# Path of the downloaded and extracted FastSDK library
option(FASTDEPLOY_INSTALL_DIR "Path of downloaded fastdeploy sdk.")

include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake)

# Add FastDeploy header directories
include_directories(${FASTDEPLOY_INCS})

add_executable(infer_demo ${PROJECT_SOURCE_DIR}/infer.cc)
# Link FastDeploy libraries
target_link_libraries(infer_demo ${FASTDEPLOY_LIBS})
diff --git a/examples/vision/matting/rvm/cpp/README.md b/examples/vision/matting/rvm/cpp/README.md
new file mode 100755
index 000000000..571e2b123
--- /dev/null
+++ b/examples/vision/matting/rvm/cpp/README.md
@@ -0,0 +1,87 @@
# RobustVideoMatting C++ Deployment Example

Before deployment, confirm the following two steps

- 1. The hardware and software environment meets the requirements, see [FastDeploy Environment Requirements](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md)
- 2. Download the prebuilt deployment library and the samples code according to your development environment, see [FastDeploy Prebuilt Libraries](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md)

This directory provides `infer.cc`, which quickly deploys RobustVideoMatting on CPU/GPU, as well as on GPU with TensorRT acceleration. Taking inference on Linux as an example, run the commands below in this directory to build and test the demo (if you only need CPU deployment, download the CPU inference library from the [FastDeploy C++ prebuilt libraries](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md) instead)

```bash
# Download the SDK and build the examples code (the SDK contains the examples code)
wget https://bj.bcebos.com/fastdeploy/release/cpp/fastdeploy-linux-x64-gpu-0.3.0.tgz
tar xvf fastdeploy-linux-x64-gpu-0.3.0.tgz
cd fastdeploy-linux-x64-gpu-0.3.0/examples/vision/matting/rvm/cpp/
mkdir build && cd build
cmake .. -DFASTDEPLOY_INSTALL_DIR=${PWD}/../../../../../../../fastdeploy-linux-x64-gpu-0.3.0
make -j

# Download the RobustVideoMatting model files, test images, and video
## Original ONNX model
wget https://bj.bcebos.com/paddlehub/fastdeploy/rvm_mobilenetv3_fp32.onnx
## ONNX model specially processed for loading with TensorRT
wget https://bj.bcebos.com/paddlehub/fastdeploy/rvm_mobilenetv3_trt.onnx
wget https://bj.bcebos.com/paddlehub/fastdeploy/matting_input.jpg
wget https://bj.bcebos.com/paddlehub/fastdeploy/matting_bgr.jpg
wget https://bj.bcebos.com/paddlehub/fastdeploy/video.mp4

# CPU inference
./infer_demo rvm_mobilenetv3_fp32.onnx matting_input.jpg matting_bgr.jpg 0
# GPU inference
./infer_demo rvm_mobilenetv3_fp32.onnx matting_input.jpg matting_bgr.jpg 1
# TensorRT inference
./infer_demo rvm_mobilenetv3_trt.onnx matting_input.jpg matting_bgr.jpg 2
```

The visualized results after running are shown below
(visualized results: the extracted foreground and the background-swapped image)

The above commands only work on Linux or macOS. For SDK usage on Windows, refer to:
- [How to use the FastDeploy C++ SDK on Windows](../../../../../docs/cn/faq/use_sdk_on_windows.md)

## RobustVideoMatting C++ Interface

```c++
fastdeploy::vision::matting::RobustVideoMatting(
    const string& model_file,
    const string& params_file = "",
    const RuntimeOption& runtime_option = RuntimeOption(),
    const ModelFormat& model_format = ModelFormat::ONNX)
```

RobustVideoMatting model loading and initialization, where model_file is the exported ONNX model.

**Parameters**

> * **model_file**(str): path of the model file
> * **params_file**(str): path of the parameter file; when the model is in ONNX format, this parameter can be left unset
> * **runtime_option**(RuntimeOption): backend inference configuration; the default uses the default configuration
> * **model_format**(ModelFormat): model format, ONNX by default

### Predict function

> ```c++
> RobustVideoMatting::Predict(cv::Mat* im, MattingResult* result)
> ```
>
> Model prediction interface. Takes an image and directly returns the matting result.
>
> **Parameters**
>
> > * **im**: input image, must be in HWC, BGR format
> > * **result**: matting result; see [Vision Model Prediction Results](../../../../../docs/api/vision_results/) for the description of MattingResult


## Other Docs

- [Model Introduction](../../)
- [Python Deployment](../python)
- [Vision Model Prediction Results](../../../../../docs/api/vision_results/)
- [How to switch the model inference backend](../../../../../docs/cn/faq/how_to_change_backend.md)
diff --git a/examples/vision/matting/rvm/cpp/infer.cc b/examples/vision/matting/rvm/cpp/infer.cc
new file mode 100755
index 000000000..9e2a2aa5d
--- /dev/null
+++ b/examples/vision/matting/rvm/cpp/infer.cc
@@ -0,0 +1,131 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "fastdeploy/vision.h"

#ifdef WIN32
const char sep = '\\';
#else
const char sep = '/';
#endif

void CpuInfer(const std::string& model_file, const std::string& image_file,
              const std::string& background_file) {
  auto option = fastdeploy::RuntimeOption();
  auto model =
      fastdeploy::vision::matting::RobustVideoMatting(model_file, "", option);
  if (!model.Initialized()) {
    std::cerr << "Failed to initialize." << std::endl;
    return;
  }

  auto im = cv::imread(image_file);
  auto im_bak = im.clone();
  cv::Mat bg = cv::imread(background_file);
  fastdeploy::vision::MattingResult res;
  if (!model.Predict(&im, &res)) {
    std::cerr << "Failed to predict." << std::endl;
    return;
  }
  auto vis_im = fastdeploy::vision::VisMatting(im_bak, res);
  auto vis_im_with_bg = fastdeploy::vision::SwapBackground(im_bak, bg, res);
  cv::imwrite("visualized_result.jpg", vis_im_with_bg);
  cv::imwrite("visualized_result_fg.jpg", vis_im);
  std::cout << "Visualized results saved in ./visualized_result.jpg "
               "and ./visualized_result_fg.jpg"
            << std::endl;
}

void GpuInfer(const std::string& model_file, const std::string& image_file,
              const std::string& background_file) {
  auto option = fastdeploy::RuntimeOption();
  option.UseGpu();
  auto model =
      fastdeploy::vision::matting::RobustVideoMatting(model_file, "", option);
  if (!model.Initialized()) {
    std::cerr << "Failed to initialize." << std::endl;
    return;
  }

  auto im = cv::imread(image_file);
  auto im_bak = im.clone();
  cv::Mat bg = cv::imread(background_file);
  fastdeploy::vision::MattingResult res;
  if (!model.Predict(&im, &res)) {
    std::cerr << "Failed to predict." << std::endl;
    return;
  }
  auto vis_im = fastdeploy::vision::VisMatting(im_bak, res);
  auto vis_im_with_bg = fastdeploy::vision::SwapBackground(im_bak, bg, res);
  cv::imwrite("visualized_result.jpg", vis_im_with_bg);
  cv::imwrite("visualized_result_fg.jpg", vis_im);
  std::cout << "Visualized results saved in ./visualized_result.jpg "
               "and ./visualized_result_fg.jpg"
            << std::endl;
}

void TrtInfer(const std::string& model_file, const std::string& image_file,
              const std::string& background_file) {
  auto option = fastdeploy::RuntimeOption();
  option.UseGpu();
  option.UseTrtBackend();
  option.SetTrtInputShape("src", {1, 3, 1920, 1080});
  // min, opt and max shapes of the dynamic recurrent-state inputs
  option.SetTrtInputShape("r1i", {1, 1, 1, 1}, {1, 16, 240, 135},
                          {1, 16, 240, 135});
  option.SetTrtInputShape("r2i", {1, 1, 1, 1}, {1, 20, 120, 68},
                          {1, 20, 120, 68});
  option.SetTrtInputShape("r3i", {1, 1, 1, 1}, {1, 40, 60, 34},
                          {1, 40, 60, 34});
  option.SetTrtInputShape("r4i", {1, 1, 1, 1}, {1, 64, 30, 17},
                          {1, 64, 30, 17});
  auto model =
      fastdeploy::vision::matting::RobustVideoMatting(model_file, "", option);
  if (!model.Initialized()) {
    std::cerr << "Failed to initialize." << std::endl;
    return;
  }

  auto im = cv::imread(image_file);
  auto im_bak = im.clone();
  cv::Mat bg = cv::imread(background_file);
  fastdeploy::vision::MattingResult res;
  if (!model.Predict(&im, &res)) {
    std::cerr << "Failed to predict." << std::endl;
    return;
  }
  auto vis_im = fastdeploy::vision::VisMatting(im_bak, res);
  auto vis_im_with_bg = fastdeploy::vision::SwapBackground(im_bak, bg, res);
  cv::imwrite("visualized_result.jpg", vis_im_with_bg);
  cv::imwrite("visualized_result_fg.jpg", vis_im);
  std::cout << "Visualized results saved in ./visualized_result.jpg "
               "and ./visualized_result_fg.jpg"
            << std::endl;
}

int main(int argc, char* argv[]) {
  if (argc < 5) {
    std::cout
        << "Usage: infer_demo path/to/model path/to/image "
           "path/to/background_image run_option, "
           "e.g. ./infer_demo ./rvm_mobilenetv3_fp32.onnx ./test.jpg ./test_bg.jpg 0"
        << std::endl;
    std::cout << "The data type of run_option is int, 0: run with cpu; 1: run "
                 "with gpu; 2: run with gpu and use tensorrt backend."
              << std::endl;
    return -1;
  }
  if (std::atoi(argv[4]) == 0) {
    CpuInfer(argv[1], argv[2], argv[3]);
  } else if (std::atoi(argv[4]) == 1) {
    GpuInfer(argv[1], argv[2], argv[3]);
  } else if (std::atoi(argv[4]) == 2) {
    TrtInfer(argv[1], argv[2], argv[3]);
  }
  return 0;
}
diff --git a/examples/vision/matting/rvm/export.md b/examples/vision/matting/rvm/export.md
new file mode 100755
index 000000000..85167754d
--- /dev/null
+++ b/examples/vision/matting/rvm/export.md
@@ -0,0 +1,116 @@
# Exporting RobustVideoMatting as Dynamic ONNX for TensorRT

## Environment Dependencies

- python >= 3.5
- pytorch 1.12.0
- onnx 1.10.0
- onnxsim 0.4.8

## Step 1: Pull the onnx branch of RobustVideoMatting

```shell
git clone -b onnx https://github.com/PeterL1n/RobustVideoMatting.git
cd RobustVideoMatting
```

## Step 2: Remove the dynamic downsample_ratio input

In ```model/model.py```, remove the ```downsample_ratio``` input as shown below

```python
def forward(self, src, r1, r2, r3, r4,
            # downsample_ratio: float = 0.25,
            segmentation_pass: bool = False):

    if torch.onnx.is_in_onnx_export():
        # src_sm = CustomOnnxResizeByFactorOp.apply(src, 0.25)
        src_sm = self._interpolate(src, scale_factor=0.25)
    elif downsample_ratio != 1:
        src_sm = self._interpolate(src, scale_factor=0.25)
    else:
        src_sm = src

    f1, f2, f3, f4 = self.backbone(src_sm)
    f4 = self.aspp(f4)
    hid, *rec = self.decoder(src_sm, f1, f2, f3, f4, r1, r2, r3, r4)

    if not segmentation_pass:
        fgr_residual, pha = self.project_mat(hid).split([3, 1], dim=-3)
        # if torch.onnx.is_in_onnx_export() or downsample_ratio != 1:
        if torch.onnx.is_in_onnx_export():
            fgr_residual, pha = self.refiner(src, src_sm, fgr_residual, pha, hid)
        fgr = fgr_residual + src
        fgr = fgr.clamp(0., 1.)
        pha = pha.clamp(0., 1.)
        return [fgr, pha, *rec]
    else:
        seg = self.project_seg(hid)
        return [seg, *rec]
```

## Step 3: Modify the ONNX export script

Modify the ```export_onnx.py``` script to drop the ```downsample_ratio``` input

```python
def export(self):
    rec = (torch.zeros([1, 1, 1, 1]).to(self.args.device, self.precision),) * 4
    # src = torch.randn(1, 3, 1080, 1920).to(self.args.device, self.precision)
    src = torch.randn(1, 3, 1920, 1080).to(self.args.device, self.precision)
    # downsample_ratio = torch.tensor([0.25]).to(self.args.device)

    dynamic_spatial = {0: 'batch_size', 2: 'height', 3: 'width'}
    dynamic_everything = {0: 'batch_size', 1: 'channels', 2: 'height', 3: 'width'}

    torch.onnx.export(
        self.model,
        # (src, *rec, downsample_ratio),
        (src, *rec),
        self.args.output,
        export_params=True,
        opset_version=self.args.opset,
        do_constant_folding=True,
        # input_names=['src', 'r1i', 'r2i', 'r3i', 'r4i', 'downsample_ratio'],
        input_names=['src', 'r1i', 'r2i', 'r3i', 'r4i'],
        output_names=['fgr', 'pha', 'r1o', 'r2o', 'r3o', 'r4o'],
        dynamic_axes={
            'src': {0: 'batch_size0', 2: 'height0', 3: 'width0'},
            'fgr': {0: 'batch_size1', 2: 'height1', 3: 'width1'},
            'pha': {0: 'batch_size2', 2: 'height2', 3: 'width2'},
            'r1i': {0: 'batch_size3', 1: 'channels3', 2: 'height3', 3: 'width3'},
            'r2i': {0: 'batch_size4', 1: 'channels4', 2: 'height4', 3: 'width4'},
            'r3i': {0: 'batch_size5', 1: 'channels5', 2: 'height5', 3: 'width5'},
            'r4i': {0: 'batch_size6', 1: 'channels6', 2: 'height6', 3: 'width6'},
            'r1o': {0: 'batch_size7', 2: 'height7', 3: 'width7'},
            'r2o': {0: 'batch_size8', 2: 'height8', 3: 'width8'},
            'r3o': {0: 'batch_size9', 2: 'height9', 3: 'width9'},
            'r4o': {0: 'batch_size10', 2: 'height10', 3: 'width10'},
        })
```

Run the following command

```shell
python export_onnx.py \
    --model-variant mobilenetv3 \
    --checkpoint rvm_mobilenetv3.pth \
    --precision float32 \
    --opset 12 \
    --device cuda \
    --output rvm_mobilenetv3.onnx
```

**Note**:
- For the dynamic shapes of a multi-input ONNX model in TensorRT, if inputs x0 and x1 have different shapes, their axes cannot both be named height/width; they must be distinguished as height0, height1, and so on. Otherwise the engine-build stage will fail

## Step 4: Simplify with onnxsim

Install onnxsim, and simplify the ONNX model exported in Step 3

```shell
pip install onnxsim
onnxsim rvm_mobilenetv3.onnx rvm_mobilenetv3_trt.onnx
```

```rvm_mobilenetv3_trt.onnx``` is the dynamic-shape ONNX model that can run on the TensorRT backend
diff --git a/examples/vision/matting/rvm/python/README.md b/examples/vision/matting/rvm/python/README.md
new file mode 100755
index 000000000..5b3676c08
--- /dev/null
+++ b/examples/vision/matting/rvm/python/README.md
@@ -0,0 +1,88 @@
# RobustVideoMatting Python Deployment Example

Before deployment, confirm the following two steps

- 1. The hardware and software environment meets the requirements, see [FastDeploy Environment Requirements](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md)
- 2. Install the FastDeploy Python whl package, see [FastDeploy Python Installation](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md)

This directory provides `infer.py`, which quickly deploys RobustVideoMatting on CPU/GPU, as well as on GPU with TensorRT acceleration. Run the script below

```bash
# Download the deployment example code
git clone https://github.com/PaddlePaddle/FastDeploy.git
cd FastDeploy/examples/vision/matting/rvm/python

# Download the RobustVideoMatting model files, test images, and video
## Original ONNX model
wget https://bj.bcebos.com/paddlehub/fastdeploy/rvm_mobilenetv3_fp32.onnx
## ONNX model specially processed for loading with TensorRT
wget https://bj.bcebos.com/paddlehub/fastdeploy/rvm_mobilenetv3_trt.onnx
wget https://bj.bcebos.com/paddlehub/fastdeploy/matting_input.jpg
wget https://bj.bcebos.com/paddlehub/fastdeploy/matting_bgr.jpg
wget https://bj.bcebos.com/paddlehub/fastdeploy/video.mp4

# CPU inference
## image
python infer.py --model rvm_mobilenetv3_fp32.onnx --image matting_input.jpg --bg matting_bgr.jpg --device cpu
## video
python infer.py --model rvm_mobilenetv3_fp32.onnx --video video.mp4 --bg matting_bgr.jpg --device cpu
# GPU inference
## image
python infer.py --model rvm_mobilenetv3_fp32.onnx --image matting_input.jpg --bg matting_bgr.jpg --device gpu
## video
python infer.py --model rvm_mobilenetv3_fp32.onnx --video video.mp4 --bg matting_bgr.jpg --device gpu
# TensorRT inference
## image
python infer.py --model rvm_mobilenetv3_trt.onnx --image matting_input.jpg --bg matting_bgr.jpg --device gpu --use_trt True
## video
python infer.py --model rvm_mobilenetv3_trt.onnx --video video.mp4 --bg matting_bgr.jpg --device gpu --use_trt True
```

The visualized results after running are shown below
(visualized results: the extracted foreground and the background-swapped output, for both image and video inputs)
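The full `infer.py` below handles videos, TensorRT options and output writing; for a first smoke test, the core image-matting flow is only a few lines. A minimal sketch (not part of the shipped example), using the files downloaded above:

```python
import cv2
import fastdeploy as fd

# Load the exported ONNX model (defaults to CPU with the default backend)
model = fd.vision.matting.RobustVideoMatting("rvm_mobilenetv3_fp32.onnx")

im = cv2.imread("matting_input.jpg")
bg = cv2.imread("matting_bgr.jpg")

# predict may modify its input during preprocessing, so pass a copy
result = model.predict(im.copy())

# Visualize the foreground/alpha and swap in the new background
cv2.imwrite("fg.jpg", fd.vision.vis_matting(im, result))
cv2.imwrite("swapped.jpg", fd.vision.swap_background_matting(im, bg, result))
```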

## RobustVideoMatting Python Interface

```python
fd.vision.matting.RobustVideoMatting(model_file, params_file=None, runtime_option=None, model_format=ModelFormat.ONNX)
```

RobustVideoMatting model loading and initialization, where model_file is the exported ONNX model

**Parameters**

> * **model_file**(str): path of the model file
> * **params_file**(str): path of the parameter file; when the model is in ONNX format, this parameter can be left unset
> * **runtime_option**(RuntimeOption): backend inference configuration; None uses the default configuration
> * **model_format**(ModelFormat): model format, ONNX by default

### predict function

> ```python
> RobustVideoMatting.predict(input_image)
> ```
>
> Model prediction interface. Takes an image and directly returns the matting result.
>
> **Parameters**
>
> > * **input_image**(np.ndarray): input data, must be in HWC, BGR format

> **Returns**
>
> > Returns a `fastdeploy.vision.MattingResult` structure; see [Vision Model Prediction Results](../../../../../docs/api/vision_results/) for its description


## Other Docs

- [RobustVideoMatting Model Introduction](..)
- [RobustVideoMatting C++ Deployment](../cpp)
- [Model Prediction Results](../../../../../docs/api/vision_results/)
- [How to switch the model inference backend](../../../../../docs/cn/faq/how_to_change_backend.md)
diff --git a/examples/vision/matting/rvm/python/infer.py b/examples/vision/matting/rvm/python/infer.py
new file mode 100755
index 000000000..11951b00f
--- /dev/null
+++ b/examples/vision/matting/rvm/python/infer.py
@@ -0,0 +1,112 @@
import fastdeploy as fd
import cv2
import os


def parse_arguments():
    import argparse
    import ast
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model", required=True, help="Path of RobustVideoMatting model.")
    parser.add_argument("--image", type=str, help="Path of test image file.")
    parser.add_argument("--video", type=str, help="Path of test video file.")
    parser.add_argument(
        "--bg",
        type=str,
        required=True,
        default=None,
        help="Path of test background image file.")
    parser.add_argument(
        '--output-composition',
        type=str,
        default="composition.mp4",
        help="Path of composition video file.")
    parser.add_argument(
        '--output-alpha',
        type=str,
        default="alpha.mp4",
        help="Path of alpha video file.")
    parser.add_argument(
        "--device",
        type=str,
        default='cpu',
        help="Type of inference device, support 'cpu' or 'gpu'.")
    parser.add_argument(
        "--use_trt",
        type=ast.literal_eval,
        default=False,
        help="Whether to use tensorrt.")
    return parser.parse_args()


def build_option(args):
    option = fd.RuntimeOption()
    if args.device.lower() == "gpu":
        option.use_gpu()
    if args.use_trt:
        option.use_trt_backend()
        option.set_trt_input_shape("src", [1, 3, 1920, 1080])
        # min, opt and max shapes of the dynamic recurrent-state inputs
        option.set_trt_input_shape("r1i", [1, 1, 1, 1], [1, 16, 240, 135],
                                   [1, 16, 240, 135])
        option.set_trt_input_shape("r2i", [1, 1, 1, 1], [1, 20, 120, 68],
                                   [1, 20, 120, 68])
        option.set_trt_input_shape("r3i", [1, 1, 1, 1], [1, 40, 60, 34],
                                   [1, 40, 60, 34])
        option.set_trt_input_shape("r4i", [1, 1, 1, 1], [1, 64, 30, 17],
                                   [1, 64, 30, 17])

    return option


args = parse_arguments()
output_composition = args.output_composition
output_alpha = args.output_alpha

# Configure the runtime and load the model
runtime_option = build_option(args)
model = fd.vision.matting.RobustVideoMatting(
    args.model, runtime_option=runtime_option)
bg = cv2.imread(args.bg)

if args.video is not None:
    # for video
    cap = cv2.VideoCapture(args.video)
    # Define the codec and create VideoWriter object
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    composition = cv2.VideoWriter(output_composition, fourcc, 20.0,
                                  (1080, 1920))
    alpha = cv2.VideoWriter(output_alpha, fourcc, 20.0, (1080, 1920))

    frame_id = 0
    while True:
        frame_id = frame_id + 1
        _, frame = cap.read()
        if frame is None:
            break
        result = model.predict(frame)
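        # Note: in video mode (the default) the model carries its recurrent
        # states (r1i~r4i) across predict() calls, so frames must be fed in order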
        vis_im = fd.vision.vis_matting(frame, result)
        vis_im_with_bg = fd.vision.swap_background_matting(frame, bg, result)
        alpha.write(vis_im)
        composition.write(vis_im_with_bg)
        cv2.waitKey(30)
    cap.release()
    composition.release()
    alpha.release()
    cv2.destroyAllWindows()
    print("Visualized result videos saved in {} and {}".format(
        output_composition, output_alpha))

if args.image is not None:
    # for image
    im = cv2.imread(args.image)
    result = model.predict(im.copy())
    print(result)
    # Visualize the results
    vis_im = fd.vision.vis_matting(im, result)
    vis_im_with_bg = fd.vision.swap_background_matting(im, bg, result)
    cv2.imwrite("visualized_result_fg.jpg", vis_im)
    cv2.imwrite("visualized_result_replaced_bg.jpg", vis_im_with_bg)
    print(
        "Visualized results saved in ./visualized_result_replaced_bg.jpg and ./visualized_result_fg.jpg"
    )
diff --git a/fastdeploy/backends/poros/common/iengine.h b/fastdeploy/backends/poros/common/iengine.h
index 5cb49e1ee..c945621c1 100755
--- a/fastdeploy/backends/poros/common/iengine.h
+++ b/fastdeploy/backends/poros/common/iengine.h
@@ -27,11 +27,6 @@ namespace baidu {
 namespace mirana {
 namespace poros {

-/**
- * the base engine class
- * every registered engine should inherit from this IEngine
- **/
-
 struct PorosGraph {
   torch::jit::Graph* graph = NULL;
   torch::jit::Node* node = NULL;
diff --git a/fastdeploy/backends/tensorrt/trt_backend.cc b/fastdeploy/backends/tensorrt/trt_backend.cc
old mode 100644
new mode 100755
index 363a9d1ce..100ce6f7d
--- a/fastdeploy/backends/tensorrt/trt_backend.cc
+++ b/fastdeploy/backends/tensorrt/trt_backend.cc
@@ -525,7 +525,7 @@ bool TrtBackend::BuildTrtEngine() {
       engine_->createExecutionContext());
   GetInputOutputInfo();

-  FDINFO << "TensorRT Engine is built succussfully." << std::endl;
+  FDINFO << "TensorRT Engine is built successfully." << std::endl;
   if (option_.serialize_file != "") {
     FDINFO << "Serialize TensorRTEngine to local file "
            << option_.serialize_file << "." << std::endl;
diff --git a/fastdeploy/vision.h b/fastdeploy/vision.h
index 10d69c458..f69129b76 100755
--- a/fastdeploy/vision.h
+++ b/fastdeploy/vision.h
@@ -41,6 +41,7 @@
 #include "fastdeploy/vision/faceid/contrib/vpl.h"
 #include "fastdeploy/vision/keypointdet/pptinypose/pptinypose.h"
 #include "fastdeploy/vision/matting/contrib/modnet.h"
+#include "fastdeploy/vision/matting/contrib/rvm.h"
 #include "fastdeploy/vision/matting/ppmatting/ppmatting.h"
 #include "fastdeploy/vision/ocr/ppocr/classifier.h"
 #include "fastdeploy/vision/ocr/ppocr/dbdetector.h"
diff --git a/fastdeploy/vision/matting/contrib/modnet.cc b/fastdeploy/vision/matting/contrib/modnet.cc
old mode 100644
new mode 100755
index 06c1ab52f..b08266547
--- a/fastdeploy/vision/matting/contrib/modnet.cc
+++ b/fastdeploy/vision/matting/contrib/modnet.cc
@@ -86,7 +86,7 @@ bool MODNet::Postprocess(
   FDASSERT((infer_result.size() == 1),
            "The default number of output tensor must be 1 according to "
            "modnet.");
-  FDTensor& alpha_tensor = infer_result.at(0);  // (1,h,w,1)
+  FDTensor& alpha_tensor = infer_result.at(0);  // (1, 1, h, w)
   FDASSERT((alpha_tensor.shape[0] == 1), "Only support batch =1 now.");
   if (alpha_tensor.dtype != FDDataType::FP32) {
     FDERROR << "Only support post process with float32 data." << std::endl;
diff --git a/fastdeploy/vision/matting/contrib/rvm.cc b/fastdeploy/vision/matting/contrib/rvm.cc
new file mode 100755
index 000000000..04b9b9316
--- /dev/null
+++ b/fastdeploy/vision/matting/contrib/rvm.cc
@@ -0,0 +1,182 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "fastdeploy/vision/matting/contrib/rvm.h"
#include "fastdeploy/utils/perf.h"
#include "fastdeploy/vision/utils/utils.h"

namespace fastdeploy {

namespace vision {

namespace matting {

RobustVideoMatting::RobustVideoMatting(const std::string& model_file,
                                       const std::string& params_file,
                                       const RuntimeOption& custom_option,
                                       const ModelFormat& model_format) {
  if (model_format == ModelFormat::ONNX) {
    valid_cpu_backends = {Backend::OPENVINO, Backend::ORT};
    valid_gpu_backends = {Backend::ORT, Backend::TRT};
  } else {
    valid_cpu_backends = {Backend::PDINFER, Backend::ORT};
    valid_gpu_backends = {Backend::PDINFER, Backend::ORT, Backend::TRT};
  }
  runtime_option = custom_option;
  runtime_option.model_format = model_format;
  runtime_option.model_file = model_file;
  runtime_option.params_file = params_file;
  initialized = Initialize();
}

bool RobustVideoMatting::Initialize() {
  // parameters for preprocess
  size = {1080, 1920};

  video_mode = true;

  if (!InitRuntime()) {
    FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
    return false;
  }
  return true;
}

bool RobustVideoMatting::Preprocess(
    Mat* mat, FDTensor* output,
    std::map<std::string, std::vector<int>>* im_info) {
  // Resize
  int resize_w = size[0];
  int resize_h = size[1];
  if (resize_h != mat->Height() || resize_w != mat->Width()) {
    Resize::Run(mat, resize_w, resize_h);
  }
  BGR2RGB::Run(mat);

  // Normalize
  std::vector<float> alpha = {1.0f / 255.0f, 1.0f / 255.0f, 1.0f / 255.0f};
  std::vector<float> beta = {0.0f, 0.0f, 0.0f};
  Convert::Run(mat, alpha, beta);
  // Record output shape of preprocessed image
  (*im_info)["output_shape"] = {mat->Height(), mat->Width()};

  HWC2CHW::Run(mat);
  Cast::Run(mat, "float");

  mat->ShareWithTensor(output);
  output->shape.insert(output->shape.begin(), 1);  // reshape to (n, c, h, w)
  return true;
}

bool RobustVideoMatting::Postprocess(
    std::vector<FDTensor>& infer_result, MattingResult* result,
    const std::map<std::string, std::vector<int>>& im_info) {
  FDASSERT((infer_result.size() == 6),
           "The default number of output tensor must be 6 according to "
           "RobustVideoMatting.");
  FDTensor& fgr = infer_result.at(0);    // fgr (1, 3, h, w) 0.~1.
  FDTensor& alpha = infer_result.at(1);  // alpha (1, 1, h, w) 0.~1.
  FDASSERT((fgr.shape[0] == 1), "Only support batch = 1 now.");
  FDASSERT((alpha.shape[0] == 1), "Only support batch = 1 now.");
  if (fgr.dtype != FDDataType::FP32) {
    FDERROR << "Only support post process with float32 data." << std::endl;
    return false;
  }
  if (alpha.dtype != FDDataType::FP32) {
    FDERROR << "Only support post process with float32 data." << std::endl;
    return false;
  }
  // update context
  if (video_mode) {
    for (size_t i = 0; i < 4; ++i) {
      FDTensor& rki = infer_result.at(i + 2);
      dynamic_inputs_dims_[i] = rki.shape;
      dynamic_inputs_datas_[i].resize(rki.Numel());
      memcpy(dynamic_inputs_datas_[i].data(), rki.Data(),
             rki.Numel() * FDDataTypeSize(rki.dtype));
    }
  }

  auto iter_in = im_info.find("input_shape");
  auto iter_out = im_info.find("output_shape");
  FDASSERT(iter_out != im_info.end() && iter_in != im_info.end(),
           "Cannot find input_shape or output_shape from im_info.");
  int out_h = iter_out->second[0];
  int out_w = iter_out->second[1];
  int in_h = iter_in->second[0];
  int in_w = iter_in->second[1];

  // for alpha
  float* alpha_ptr = static_cast<float*>(alpha.Data());
  cv::Mat alpha_zero_copy_ref(out_h, out_w, CV_32FC1, alpha_ptr);
  Mat alpha_resized(alpha_zero_copy_ref);  // ref-only, zero copy.
  if ((out_h != in_h) || (out_w != in_w)) {
    // already allocated a new continuous memory after resize.
    Resize::Run(&alpha_resized, in_w, in_h, -1, -1);
  }

  // for foreground
  float* fgr_ptr = static_cast<float*>(fgr.Data());
  cv::Mat fgr_zero_copy_ref(out_h, out_w, CV_32FC1, fgr_ptr);
  Mat fgr_resized(fgr_zero_copy_ref);  // ref-only, zero copy.
  if ((out_h != in_h) || (out_w != in_w)) {
    // already allocated a new continuous memory after resize.
    Resize::Run(&fgr_resized, in_w, in_h, -1, -1);
  }

  result->Clear();
  result->contain_foreground = true;
  result->shape = {static_cast<int64_t>(in_h), static_cast<int64_t>(in_w)};
  int numel = in_h * in_w;
  int nbytes = numel * sizeof(float);
  result->Resize(numel);
  memcpy(result->alpha.data(), alpha_resized.GetOpenCVMat()->data, nbytes);
  memcpy(result->foreground.data(), fgr_resized.GetOpenCVMat()->data, nbytes);
  return true;
}

bool RobustVideoMatting::Predict(cv::Mat* im, MattingResult* result) {
  Mat mat(*im);
  int inputs_nums = NumInputsOfRuntime();
  std::vector<FDTensor> input_tensors(inputs_nums);
  std::map<std::string, std::vector<int>> im_info;
  // Record the shape of image and the shape of preprocessed image
  im_info["input_shape"] = {mat.Height(), mat.Width()};
  im_info["output_shape"] = {mat.Height(), mat.Width()};
  // convert vector to FDTensor
  for (size_t i = 1; i < inputs_nums; ++i) {
    input_tensors[i].SetExternalData(dynamic_inputs_dims_[i - 1],
                                     FDDataType::FP32,
                                     dynamic_inputs_datas_[i - 1].data());
    input_tensors[i].device = Device::CPU;
  }
  if (!Preprocess(&mat, &input_tensors[0], &im_info)) {
    FDERROR << "Failed to preprocess input image." << std::endl;
    return false;
  }
  for (size_t i = 0; i < inputs_nums; ++i) {
    input_tensors[i].name = InputInfoOfRuntime(i).name;
  }
  std::vector<FDTensor> output_tensors;
  if (!Infer(input_tensors, &output_tensors)) {
    FDERROR << "Failed to inference." << std::endl;
    return false;
  }

  if (!Postprocess(output_tensors, result, im_info)) {
    FDERROR << "Failed to post process." << std::endl;
    return false;
  }
  return true;
}

}  // namespace matting
}  // namespace vision
}  // namespace fastdeploy
\ No newline at end of file
diff --git a/fastdeploy/vision/matting/contrib/rvm.h b/fastdeploy/vision/matting/contrib/rvm.h
new file mode 100755
index 000000000..58c64ac3b
--- /dev/null
+++ b/fastdeploy/vision/matting/contrib/rvm.h
@@ -0,0 +1,92 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include "fastdeploy/fastdeploy_model.h"
#include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/result.h"

namespace fastdeploy {

namespace vision {
/** \brief All image/video matting model APIs are defined inside this namespace
 *
 */
namespace matting {

/*! @brief RobustVideoMatting model object, used to load a matting model exported by RobustVideoMatting
 */
class FASTDEPLOY_DECL RobustVideoMatting : public FastDeployModel {
 public:
  /** \brief Set path of model file and configuration file, and the configuration of runtime
   *
   * \param[in] model_file Path of model file, e.g rvm/rvm_mobilenetv3_fp32.onnx
   * \param[in] params_file Path of parameter file, if the model format is ONNX, this parameter will be ignored
   * \param[in] custom_option RuntimeOption for inference, the default will use cpu, and choose the backend defined in `valid_cpu_backends`
   * \param[in] model_format Model format of the loaded model, default is ONNX format
   */
  RobustVideoMatting(const std::string& model_file,
                     const std::string& params_file = "",
                     const RuntimeOption& custom_option = RuntimeOption(),
                     const ModelFormat& model_format = ModelFormat::ONNX);

  /// Get model's name
  std::string ModelName() const { return "matting/RobustVideoMatting"; }

  /** \brief Predict the matting result for an input image
   *
   * \param[in] im The input image data, comes from cv::imread()
   * \param[in] result The output matting result will be written to this structure
   * \return true if the prediction succeeded, otherwise false
   */
  bool Predict(cv::Mat* im, MattingResult* result);

  /// Preprocess image size, the default is (1080, 1920)
  std::vector<int> size;

  /// Whether video mode is enabled; if the inputs are unrelated images, set it to false. The default is true // NOLINT
  bool video_mode;

 private:
  bool Initialize();
  /// Preprocess an input image, and set the preprocessed results to `outputs`
  bool Preprocess(Mat* mat, FDTensor* output,
                  std::map<std::string, std::vector<int>>* im_info);

  /// Postprocess the inferenced results, and set the final result to `result`
  bool Postprocess(std::vector<FDTensor>& infer_result, MattingResult* result,
                   const std::map<std::string, std::vector<int>>& im_info);

  /// Initial data of the dynamic inputs
  std::vector<std::vector<float>> dynamic_inputs_datas_ = {
      {0.0f},   // r1i
      {0.0f},   // r2i
      {0.0f},   // r3i
      {0.0f},   // r4i
      {0.25f},  // downsample_ratio
  };

  /// Initial dims of the dynamic inputs
  std::vector<std::vector<int64_t>> dynamic_inputs_dims_ = {
      {1, 1, 1, 1},  // r1i
      {1, 1, 1, 1},  // r2i
      {1, 1, 1, 1},  // r3i
      {1, 1, 1, 1},  // r4i
      {1},           // downsample_ratio
  };
};

}  // namespace matting
}  // namespace vision
}  // namespace fastdeploy
diff --git a/fastdeploy/vision/matting/contrib/rvm_pybind.cc b/fastdeploy/vision/matting/contrib/rvm_pybind.cc
new file mode 100755
index 000000000..a45816d65
--- /dev/null
+++ b/fastdeploy/vision/matting/contrib/rvm_pybind.cc
@@ -0,0 +1,34 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "fastdeploy/pybind/main.h"

namespace fastdeploy {
void BindRobustVideoMatting(pybind11::module& m) {
  // Bind RobustVideoMatting
  pybind11::class_<vision::matting::RobustVideoMatting, FastDeployModel>(
      m, "RobustVideoMatting")
      .def(pybind11::init<std::string, std::string, RuntimeOption,
                          ModelFormat>())
      .def("predict",
           [](vision::matting::RobustVideoMatting& self,
              pybind11::array& data) {
             auto mat = PyArrayToCvMat(data);
             vision::MattingResult res;
             self.Predict(&mat, &res);
             return res;
           })
      .def_readwrite("size", &vision::matting::RobustVideoMatting::size)
      .def_readwrite("video_mode",
                     &vision::matting::RobustVideoMatting::video_mode);
}

}  // namespace fastdeploy
diff --git a/fastdeploy/vision/matting/matting_pybind.cc b/fastdeploy/vision/matting/matting_pybind.cc
index 8c514a428..204dd7192 100644
--- a/fastdeploy/vision/matting/matting_pybind.cc
+++ b/fastdeploy/vision/matting/matting_pybind.cc
@@ -17,12 +17,14 @@ namespace fastdeploy {

 void BindMODNet(pybind11::module& m);
+void BindRobustVideoMatting(pybind11::module& m);
 void BindPPMatting(pybind11::module& m);

 void BindMatting(pybind11::module& m) {
   auto matting_module =
-      m.def_submodule("matting", "Image object matting models.");
+      m.def_submodule("matting", "Image/Video matting models.");
   BindMODNet(matting_module);
+  BindRobustVideoMatting(matting_module);
   BindPPMatting(matting_module);
 }
 }  // namespace fastdeploy
diff --git a/fastdeploy/vision/matting/ppmatting/ppmatting.cc b/fastdeploy/vision/matting/ppmatting/ppmatting.cc
old mode 100644
new mode 100755
index 9c342d315..e760ab523
--- a/fastdeploy/vision/matting/ppmatting/ppmatting.cc
+++ b/fastdeploy/vision/matting/ppmatting/ppmatting.cc
@@ -12,7 +12,6 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 #include "fastdeploy/vision/matting/ppmatting/ppmatting.h"
-#include "fastdeploy/vision.h"
 #include "fastdeploy/vision/utils/utils.h"
 #include "yaml-cpp/yaml.h"
@@ -163,8 +162,8 @@ bool PPMatting::Postprocess(
     const std::map<std::string, std::array<int, 2>>& im_info) {
   FDASSERT((infer_result.size() == 1),
            "The default number of output tensor must be 1 ");
-  FDTensor& alpha_tensor = infer_result.at(0);  // (1,h,w,1)
-  FDASSERT((alpha_tensor.shape[0] == 1), "Only support batch =1 now.");
+  FDTensor& alpha_tensor = infer_result.at(0);  // (1, 1, h, w)
+  FDASSERT((alpha_tensor.shape[0] == 1), "Only support batch = 1 now.");
   if (alpha_tensor.dtype != FDDataType::FP32) {
     FDERROR << "Only support post process with float32 data." << std::endl;
     return false;
diff --git a/fastdeploy/vision/segmentation/ppseg/model.cc b/fastdeploy/vision/segmentation/ppseg/model.cc
old mode 100644
new mode 100755
index cd28836fb..21c377485
--- a/fastdeploy/vision/segmentation/ppseg/model.cc
+++ b/fastdeploy/vision/segmentation/ppseg/model.cc
@@ -13,7 +13,6 @@
 // limitations under the License.
 #include "fastdeploy/vision/segmentation/ppseg/model.h"
-#include "fastdeploy/vision.h"
 #include "fastdeploy/vision/utils/utils.h"
 #include "yaml-cpp/yaml.h"
diff --git a/python/fastdeploy/vision/matting/__init__.py b/python/fastdeploy/vision/matting/__init__.py
index 9a03d81e0..43b0acbcf 100644
--- a/python/fastdeploy/vision/matting/__init__.py
+++ b/python/fastdeploy/vision/matting/__init__.py
@@ -14,4 +14,5 @@
 from __future__ import absolute_import
 from .contrib.modnet import MODNet
+from .contrib.rvm import RobustVideoMatting
 from .ppmatting import PPMatting
diff --git a/python/fastdeploy/vision/matting/contrib/rvm.py b/python/fastdeploy/vision/matting/contrib/rvm.py
new file mode 100755
index 000000000..144a3823c
--- /dev/null
+++ b/python/fastdeploy/vision/matting/contrib/rvm.py
@@ -0,0 +1,81 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
import logging
from .... import FastDeployModel, ModelFormat
from .... import c_lib_wrap as C


class RobustVideoMatting(FastDeployModel):
    def __init__(self,
                 model_file,
                 params_file="",
                 runtime_option=None,
                 model_format=ModelFormat.ONNX):
        """Load a video matting model exported by RobustVideoMatting.

        :param model_file: (str)Path of model file, e.g rvm/rvm_mobilenetv3_fp32.onnx
        :param params_file: (str)Path of parameters file; if the model_format is ModelFormat.ONNX, this param will be ignored and can be set as an empty string
        :param runtime_option: (fastdeploy.RuntimeOption)RuntimeOption for inferencing this model; if it's None, will use the default backend on CPU
        :param model_format: (fastdeploy.ModelFormat)Model format of the loaded model, default is ONNX
        """
        super(RobustVideoMatting, self).__init__(runtime_option)

        self._model = C.vision.matting.RobustVideoMatting(
            model_file, params_file, self._runtime_option, model_format)
        assert self.initialized, "RobustVideoMatting initialize failed."

    def predict(self, input_image):
        """Predict the matting result for an input image

        :param input_image: (numpy.ndarray)The input image data, 3-D array with layout HWC, BGR format
        :return: MattingResult
        """
        return self._model.predict(input_image)

    @property
    def size(self):
        """
        Returns the preprocess image size
        """
        return self._model.size

    @property
    def video_mode(self):
        """
        Whether video mode is enabled; if the inputs are unrelated images, set it to False. The default is True
        """
        return self._model.video_mode

    @size.setter
    def size(self, wh):
        """
        Set the preprocess image size
        """
        assert isinstance(wh, (list, tuple)),\
            "The value to set `size` must be type of tuple or list."
        assert len(wh) == 2,\
            "The value to set `size` must contain 2 elements meaning [width, height], but now it contains {} elements.".format(
                len(wh))
        self._model.size = wh

    @video_mode.setter
    def video_mode(self, value):
        """
        Set video_mode property, the default is True
        """
        assert isinstance(
            value, bool), "The value to set `video_mode` must be type of bool."
        self._model.video_mode = value
diff --git a/tests/eval_example/test_rvm.py b/tests/eval_example/test_rvm.py
new file mode 100644
index 000000000..4b8d5afe8
--- /dev/null
+++ b/tests/eval_example/test_rvm.py
@@ -0,0 +1,101 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import fastdeploy as fd
import cv2
import os
import pickle
import numpy as np


def test_matting_rvm_cpu():
    model_url = "https://bj.bcebos.com/paddlehub/fastdeploy/rvm.tgz"
    input_url = "https://bj.bcebos.com/paddlehub/fastdeploy/video.mp4"
    fd.download_and_decompress(model_url, ".")
    fd.download(input_url, ".")
    model_path = "rvm/rvm_mobilenetv3_fp32.onnx"
    # use ORT
    runtime_option = fd.RuntimeOption()
    runtime_option.use_ort_backend()
    model = fd.vision.matting.RobustVideoMatting(
        model_path, runtime_option=runtime_option)

    cap = cv2.VideoCapture("./video.mp4")

    frame_id = 0
    while True:
        _, frame = cap.read()
        if frame is None:
            break
        result = model.predict(frame)
        # compare diff
        expect_alpha = np.load("rvm/result_alpha_" + str(frame_id) + ".npy")
        result_alpha = np.array(result.alpha).reshape(1920, 1080)
        diff = np.fabs(expect_alpha - result_alpha)
        thres = 1e-05
        assert diff.max(
        ) < thres, "The label diff is %f, which is bigger than %f" % (
            diff.max(), thres)
        frame_id = frame_id + 1
        cv2.waitKey(30)
        if frame_id >= 10:
            cap.release()
            cv2.destroyAllWindows()
            break


def test_matting_rvm_gpu_trt():
    model_url = "https://bj.bcebos.com/paddlehub/fastdeploy/rvm.tgz"
    input_url = "https://bj.bcebos.com/paddlehub/fastdeploy/video.mp4"
    fd.download_and_decompress(model_url, ".")
    fd.download(input_url, ".")
    model_path = "rvm/rvm_mobilenetv3_trt.onnx"
    # use TRT
    runtime_option = fd.RuntimeOption()
    runtime_option.use_gpu()
    runtime_option.use_trt_backend()
    runtime_option.set_trt_input_shape("src", [1, 3, 1920, 1080])
    runtime_option.set_trt_input_shape("r1i", [1, 1, 1, 1], [1, 16, 240, 135],
                                       [1, 16, 240, 135])
    runtime_option.set_trt_input_shape("r2i", [1, 1, 1, 1], [1, 20, 120, 68],
                                       [1, 20, 120, 68])
    runtime_option.set_trt_input_shape("r3i", [1, 1, 1, 1], [1, 40, 60, 34],
                                       [1, 40, 60, 34])
    runtime_option.set_trt_input_shape("r4i", [1, 1, 1, 1], [1, 64, 30, 17],
                                       [1, 64, 30, 17])
    model = fd.vision.matting.RobustVideoMatting(
        model_path, runtime_option=runtime_option)

    cap = cv2.VideoCapture("./video.mp4")

    frame_id = 0
    while True:
        _, frame = cap.read()
        if frame is None:
            break
        result = model.predict(frame)
        # compare diff
        expect_alpha = np.load("rvm/result_alpha_" + str(frame_id) + ".npy")
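        # result.alpha is returned flattened; reshape it to the model's
        # (1920, 1080) inference resolution before comparing to the reference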
+ result_alpha = np.array(result.alpha).reshape(1920, 1080) + diff = np.fabs(expect_alpha - result_alpha) + thres = 1e-04 + assert diff.max( + ) < thres, "The label diff is %f, which is bigger than %f" % ( + diff.max(), thres) + frame_id = frame_id + 1 + cv2.waitKey(30) + if frame_id >= 10: + cap.release() + cv2.destroyAllWindows() + break
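
These checks follow pytest conventions, so (assuming the FastDeploy Python wheel and pytest are installed) they can be run directly from the repository root:

```shell
pytest tests/eval_example/test_rvm.py
```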