diff --git a/examples/vision/matting/README.md b/examples/vision/matting/README.md
old mode 100644
new mode 100755
index d4b4f5c32..f7582fcdf
--- a/examples/vision/matting/README.md
+++ b/examples/vision/matting/README.md
@@ -5,6 +5,7 @@ FastDeploy currently supports the deployment of the following matting models
| Model | Description | Model Format | Version |
| :--- | :--- | :------- | :--- |
| [ZHKKKe/MODNet](./modnet) | MODNet models | ONNX | [CommitID:28165a4](https://github.com/ZHKKKe/MODNet/commit/28165a4) |
+| [PeterL1n/RobustVideoMatting](./rvm) | RobustVideoMatting models | ONNX | [CommitID:81a1093](https://github.com/PeterL1n/RobustVideoMatting/commit/81a1093) |
| [PaddleSeg/PP-Matting](./ppmatting) | PP-Matting models | Paddle | [Release/2.6](https://github.com/PaddlePaddle/PaddleSeg/tree/release/2.6/Matting) |
| [PaddleSeg/PP-HumanMatting](./ppmatting) | PP-HumanMatting models | Paddle | [Release/2.6](https://github.com/PaddlePaddle/PaddleSeg/tree/release/2.6/Matting) |
| [PaddleSeg/ModNet](./ppmatting) | ModNet models | Paddle | [Release/2.6](https://github.com/PaddlePaddle/PaddleSeg/tree/release/2.6/Matting) |
diff --git a/examples/vision/matting/modnet/python/README.md b/examples/vision/matting/modnet/python/README.md
old mode 100644
new mode 100755
index decfef8cf..d84d95ac5
--- a/examples/vision/matting/modnet/python/README.md
+++ b/examples/vision/matting/modnet/python/README.md
@@ -52,16 +52,14 @@ MODNet model loading and initialization, where model_file is the exported ONNX model
### predict function
> ```python
-> MODNet.predict(image_data, conf_threshold=0.25, nms_iou_threshold=0.5)
+> MODNet.predict(image_data)
> ```
>
-> Model prediction interface; takes an image and returns the detection result directly.
+> Model prediction interface; takes an image and returns the matting result directly.
>
> **Parameters**
>
> > * **image_data**(np.ndarray): Input data; note that it must be an HWC image in BGR format
-> > * **conf_threshold**(float): Confidence threshold for filtering detection boxes
-> > * **nms_iou_threshold**(float): IoU threshold used during NMS
> **Return**
>
diff --git a/examples/vision/matting/rvm/README.md b/examples/vision/matting/rvm/README.md
new file mode 100755
index 000000000..16f33aae4
--- /dev/null
+++ b/examples/vision/matting/rvm/README.md
@@ -0,0 +1,30 @@
+# RobustVideoMatting Model Deployment
+
+## Model Version
+
+- [RobustVideoMatting](https://github.com/PeterL1n/RobustVideoMatting/commit/81a1093)
+
+## Supported Models
+
+FastDeploy currently supports the deployment of the following models
+
+- [RobustVideoMatting models](https://github.com/PeterL1n/RobustVideoMatting)
+
+## Download Pre-trained Models
+
+For developers' convenience, the exported RobustVideoMatting models are provided below and can be downloaded and used directly.
+
+| Model | Parameter Size | Accuracy | Note |
+|:---------------------------------------------------------------- |:----- |:----- | :------ |
+| [rvm_mobilenetv3_fp32.onnx](https://bj.bcebos.com/paddlehub/fastdeploy/rvm_mobilenetv3_fp32.onnx) | 15MB | - | - |
+| [rvm_resnet50_fp32.onnx](https://bj.bcebos.com/paddlehub/fastdeploy/rvm_resnet50_fp32.onnx) | 103MB | - | - |
+| [rvm_mobilenetv3_trt.onnx](https://bj.bcebos.com/paddlehub/fastdeploy/rvm_mobilenetv3_trt.onnx) | 15MB | - | exported for the TensorRT backend |
+| [rvm_resnet50_trt.onnx](https://bj.bcebos.com/paddlehub/fastdeploy/rvm_resnet50_trt.onnx) | 103MB | - | exported for the TensorRT backend |
+
+**Note**:
+- To run inference with TensorRT, download the ONNX model files whose names end with trt (a minimal usage sketch follows below)
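+
+A minimal Python sketch of what this means (assuming the FastDeploy Python package is installed; the recurrent inputs also need TensorRT shape ranges, see the [Python deployment document](python) for the full example):
+
+```python
+import fastdeploy as fd
+
+option = fd.RuntimeOption()
+option.use_gpu()
+option.use_trt_backend()  # the TensorRT backend requires the *_trt.onnx export
+# r1i-r4i also need option.set_trt_input_shape(...) ranges, see python/infer.py
+model = fd.vision.matting.RobustVideoMatting(
+    "rvm_mobilenetv3_trt.onnx", runtime_option=option)
+```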
+
+## Detailed Deployment Documents
+
+- [Python Deployment](python)
+- [C++ Deployment](cpp)
diff --git a/examples/vision/matting/rvm/cpp/CMakeLists.txt b/examples/vision/matting/rvm/cpp/CMakeLists.txt
new file mode 100644
index 000000000..93540a7e8
--- /dev/null
+++ b/examples/vision/matting/rvm/cpp/CMakeLists.txt
@@ -0,0 +1,14 @@
+PROJECT(infer_demo C CXX)
+CMAKE_MINIMUM_REQUIRED (VERSION 3.10)
+
+# Specify the path of the downloaded and extracted FastDeploy SDK
+option(FASTDEPLOY_INSTALL_DIR "Path of downloaded fastdeploy sdk.")
+
+include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake)
+
+# Add the FastDeploy header dependencies
+include_directories(${FASTDEPLOY_INCS})
+
+add_executable(infer_demo ${PROJECT_SOURCE_DIR}/infer.cc)
+# Link against the FastDeploy library
+target_link_libraries(infer_demo ${FASTDEPLOY_LIBS})
diff --git a/examples/vision/matting/rvm/cpp/README.md b/examples/vision/matting/rvm/cpp/README.md
new file mode 100755
index 000000000..571e2b123
--- /dev/null
+++ b/examples/vision/matting/rvm/cpp/README.md
@@ -0,0 +1,87 @@
+# RobustVideoMatting C++ Deployment Example
+
+Before deployment, confirm the following two steps
+
+- 1. The hardware and software environment meets the requirements, see [FastDeploy Environment Requirements](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md)
+- 2. Download the prebuilt deployment library and samples code according to your development environment, see [FastDeploy Prebuilt Libraries](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md)
+
+Taking RobustVideoMatting inference on Linux as an example, run the following commands in this directory to complete the build and test (if you only need CPU deployment, download the CPU inference library from the [FastDeploy C++ prebuilt libraries](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md))
+
+This directory provides `infer.cc` to quickly deploy RobustVideoMatting on CPU/GPU, and on GPU with TensorRT acceleration. Run the following script to complete it
+
+```bash
+# Download the SDK and build the model examples (the SDK contains the examples code)
+wget https://bj.bcebos.com/fastdeploy/release/cpp/fastdeploy-linux-x64-gpu-0.3.0.tgz
+tar xvf fastdeploy-linux-x64-gpu-0.3.0.tgz
+cd fastdeploy-linux-x64-gpu-0.3.0/examples/vision/matting/rvm/cpp/
+mkdir build && cd build
+cmake .. -DFASTDEPLOY_INSTALL_DIR=${PWD}/../../../../../../../fastdeploy-linux-x64-gpu-0.3.0
+make -j
+
+# Download the RobustVideoMatting model files, test image, and video
+## Original ONNX model
+wget https://bj.bcebos.com/paddlehub/fastdeploy/rvm_mobilenetv3_fp32.onnx
+## ONNX model specially processed for TensorRT
+wget https://bj.bcebos.com/paddlehub/fastdeploy/rvm_mobilenetv3_trt.onnx
+wget https://bj.bcebos.com/paddlehub/fastdeploy/matting_input.jpg
+wget https://bj.bcebos.com/paddlehub/fastdeploy/matting_bgr.jpg
+wget https://bj.bcebos.com/paddlehub/fastdeploy/video.mp4
+
+# CPU inference
+./infer_demo rvm_mobilenetv3_fp32.onnx matting_input.jpg matting_bgr.jpg 0
+# GPU inference
+./infer_demo rvm_mobilenetv3_fp32.onnx matting_input.jpg matting_bgr.jpg 1
+# TensorRT inference
+./infer_demo rvm_mobilenetv3_trt.onnx matting_input.jpg matting_bgr.jpg 2
+```
+
+After running, the visualized results are as shown below
+
+
+The above commands only work on Linux or macOS. For how to use the SDK on Windows, please refer to:
+- [How to use the FastDeploy C++ SDK on Windows](../../../../../docs/cn/faq/use_sdk_on_windows.md)
+
+## RobustVideoMatting C++ API
+
+```c++
+fastdeploy::vision::matting::RobustVideoMatting(
+ const string& model_file,
+ const string& params_file = "",
+ const RuntimeOption& runtime_option = RuntimeOption(),
+ const ModelFormat& model_format = ModelFormat::ONNX)
+```
+
+RobustVideoMatting model loading and initialization, where model_file is the exported ONNX model.
+
+**Parameters**
+
+> * **model_file**(str): Path of the model file
+> * **params_file**(str): Path of the parameter file; when the model format is ONNX, this parameter does not need to be set
+> * **runtime_option**(RuntimeOption): Backend inference configuration; None by default, which means the default configuration is used
+> * **model_format**(ModelFormat): Model format, ONNX by default
+
+#### Predict function
+
+> ```c++
+> RobustVideoMatting::Predict(cv::Mat* im, MattingResult* result)
+> ```
+>
+> Model prediction interface; takes an image and returns the matting result directly.
+>
+> **Parameters**
+>
+> > * **im**: Input image; note that it must be an HWC image in BGR format
+> > * **result**: Matting result; see [Vision Model Prediction Results](../../../../../docs/api/vision_results/) for the description of MattingResult
+
+
+## Other Documents
+
+- [Model Description](../../)
+- [Python Deployment](../python)
+- [Vision Model Prediction Results](../../../../../docs/api/vision_results/)
+- [How to switch the model inference backend](../../../../../docs/cn/faq/how_to_change_backend.md)
diff --git a/examples/vision/matting/rvm/cpp/infer.cc b/examples/vision/matting/rvm/cpp/infer.cc
new file mode 100755
index 000000000..9e2a2aa5d
--- /dev/null
+++ b/examples/vision/matting/rvm/cpp/infer.cc
@@ -0,0 +1,131 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fastdeploy/vision.h"
+
+#ifdef WIN32
+const char sep = '\\';
+#else
+const char sep = '/';
+#endif
+
+void CpuInfer(const std::string& model_file, const std::string& image_file,
+ const std::string& background_file) {
+ auto option = fastdeploy::RuntimeOption();
+ auto model = fastdeploy::vision::matting::RobustVideoMatting(model_file, "", option);
+ if (!model.Initialized()) {
+ std::cerr << "Failed to initialize." << std::endl;
+ return;
+ }
+
+ auto im = cv::imread(image_file);
+ auto im_bak = im.clone();
+ cv::Mat bg = cv::imread(background_file);
+ fastdeploy::vision::MattingResult res;
+ if (!model.Predict(&im, &res)) {
+ std::cerr << "Failed to predict." << std::endl;
+ return;
+ }
+ auto vis_im = fastdeploy::vision::VisMatting(im_bak, res);
+ auto vis_im_with_bg =
+ fastdeploy::vision::SwapBackground(im_bak, bg, res);
+ cv::imwrite("visualized_result.jpg", vis_im_with_bg);
+ cv::imwrite("visualized_result_fg.jpg", vis_im);
+ std::cout << "Visualized result save in ./visualized_result.jpg "
+ "and ./visualized_result_fg.jpg"
+ << std::endl;
+}
+
+void GpuInfer(const std::string& model_file, const std::string& image_file,
+ const std::string& background_file) {
+ auto option = fastdeploy::RuntimeOption();
+ option.UseGpu();
+ auto model = fastdeploy::vision::matting::RobustVideoMatting(model_file, "", option);
+ if (!model.Initialized()) {
+ std::cerr << "Failed to initialize." << std::endl;
+ return;
+ }
+
+ auto im = cv::imread(image_file);
+ auto im_bak = im.clone();
+ cv::Mat bg = cv::imread(background_file);
+ fastdeploy::vision::MattingResult res;
+ if (!model.Predict(&im, &res)) {
+ std::cerr << "Failed to predict." << std::endl;
+ return;
+ }
+ auto vis_im = fastdeploy::vision::VisMatting(im_bak, res);
+ auto vis_im_with_bg =
+ fastdeploy::vision::SwapBackground(im_bak, bg, res);
+ cv::imwrite("visualized_result.jpg", vis_im_with_bg);
+ cv::imwrite("visualized_result_fg.jpg", vis_im);
+ std::cout << "Visualized result save in ./visualized_result_replaced_bg.jpg "
+ "and ./visualized_result_fg.jpg"
+ << std::endl;
+}
+
+void TrtInfer(const std::string& model_file, const std::string& image_file,
+ const std::string& background_file) {
+ auto option = fastdeploy::RuntimeOption();
+ option.UseGpu();
+ option.UseTrtBackend();
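+  // The recurrent inputs r1i-r4i are 1x1x1x1 on the first frame and grow to their
+  // steady-state shapes afterwards, so give TensorRT a min/opt/max range for each.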
+ option.SetTrtInputShape("src", {1, 3, 1920, 1080});
+ option.SetTrtInputShape("r1i", {1, 1, 1, 1}, {1, 16, 240, 135}, {1, 16, 240, 135});
+ option.SetTrtInputShape("r2i", {1, 1, 1, 1}, {1, 20, 120, 68}, {1, 20, 120, 68});
+ option.SetTrtInputShape("r3i", {1, 1, 1, 1}, {1, 40, 60, 34}, {1, 40, 60, 34});
+ option.SetTrtInputShape("r4i", {1, 1, 1, 1}, {1, 64, 30, 17}, {1, 64, 30, 17});
+ auto model = fastdeploy::vision::matting::RobustVideoMatting(model_file, "", option);
+ if (!model.Initialized()) {
+ std::cerr << "Failed to initialize." << std::endl;
+ return;
+ }
+
+ auto im = cv::imread(image_file);
+ auto im_bak = im.clone();
+ cv::Mat bg = cv::imread(background_file);
+ fastdeploy::vision::MattingResult res;
+ if (!model.Predict(&im, &res)) {
+ std::cerr << "Failed to predict." << std::endl;
+ return;
+ }
+ auto vis_im = fastdeploy::vision::VisMatting(im_bak, res);
+ auto vis_im_with_bg =
+ fastdeploy::vision::SwapBackground(im_bak, bg, res);
+ cv::imwrite("visualized_result.jpg", vis_im_with_bg);
+ cv::imwrite("visualized_result_fg.jpg", vis_im);
+ std::cout << "Visualized result save in ./visualized_result.jpg "
+ "and ./visualized_result_fg.jpg"
+ << std::endl;
+}
+
+int main(int argc, char* argv[]) {
+ if (argc < 5) {
+ std::cout
+ << "Usage: infer_demo path/to/model_dir path/to/image run_option, "
+ "e.g ./infer_model ./rvm_mobilenetv3_fp32.onnx ./test.jpg ./test_bg.jpg 0"
+ << std::endl;
+ std::cout << "The data type of run_option is int, 0: run with cpu; 1: run "
+ "with gpu; 2: run with gpu and use tensorrt backend."
+ << std::endl;
+ return -1;
+ }
+ if (std::atoi(argv[4]) == 0) {
+ CpuInfer(argv[1], argv[2], argv[3]);
+ } else if (std::atoi(argv[4]) == 1) {
+ GpuInfer(argv[1], argv[2], argv[3]);
+ } else if (std::atoi(argv[4]) == 2) {
+ TrtInfer(argv[1], argv[2], argv[3]);
+ }
+ return 0;
+}
diff --git a/examples/vision/matting/rvm/export.md b/examples/vision/matting/rvm/export.md
new file mode 100755
index 000000000..85167754d
--- /dev/null
+++ b/examples/vision/matting/rvm/export.md
@@ -0,0 +1,116 @@
+# Exporting RobustVideoMatting to dynamic ONNX with TRT support
+
+## Environment Requirements
+
+- python >= 3.5
+- pytorch 1.12.0
+- onnx 1.10.0
+- onnxsim 0.4.8
+
+## Step 1: Pull the RobustVideoMatting onnx branch
+
+```shell
+git clone -b onnx https://github.com/PeterL1n/RobustVideoMatting.git
+cd RobustVideoMatting
+```
+
+## Step 2: Remove the downsample_ratio dynamic input
+
+In ```model/model.py```, remove the ```downsample_ratio``` input, as shown below
+
+```python
+def forward(self, src, r1, r2, r3, r4,
+ # downsample_ratio: float = 0.25,
+ segmentation_pass: bool = False):
+
+ if torch.onnx.is_in_onnx_export():
+ # src_sm = CustomOnnxResizeByFactorOp.apply(src, 0.25)
+ src_sm = self._interpolate(src, scale_factor=0.25)
+    # elif downsample_ratio != 1:
+    #     src_sm = self._interpolate(src, scale_factor=0.25)
+ else:
+ src_sm = src
+
+ f1, f2, f3, f4 = self.backbone(src_sm)
+ f4 = self.aspp(f4)
+ hid, *rec = self.decoder(src_sm, f1, f2, f3, f4, r1, r2, r3, r4)
+
+ if not segmentation_pass:
+ fgr_residual, pha = self.project_mat(hid).split([3, 1], dim=-3)
+ # if torch.onnx.is_in_onnx_export() or downsample_ratio != 1:
+ if torch.onnx.is_in_onnx_export():
+ fgr_residual, pha = self.refiner(src, src_sm, fgr_residual, pha, hid)
+ fgr = fgr_residual + src
+ fgr = fgr.clamp(0., 1.)
+ pha = pha.clamp(0., 1.)
+ return [fgr, pha, *rec]
+ else:
+ seg = self.project_seg(hid)
+ return [seg, *rec]
+```
+
+## Step 3: Modify the ONNX export script
+
+Modify the ```export_onnx.py``` script to remove the ```downsample_ratio``` input
+
+```python
+def export(self):
+ rec = (torch.zeros([1, 1, 1, 1]).to(self.args.device, self.precision),) * 4
+ # src = torch.randn(1, 3, 1080, 1920).to(self.args.device, self.precision)
+ src = torch.randn(1, 3, 1920, 1080).to(self.args.device, self.precision)
+ # downsample_ratio = torch.tensor([0.25]).to(self.args.device)
+
+ dynamic_spatial = {0: 'batch_size', 2: 'height', 3: 'width'}
+ dynamic_everything = {0: 'batch_size', 1: 'channels', 2: 'height', 3: 'width'}
+
+ torch.onnx.export(
+ self.model,
+ # (src, *rec, downsample_ratio),
+ (src, *rec),
+ self.args.output,
+ export_params=True,
+ opset_version=self.args.opset,
+ do_constant_folding=True,
+ # input_names=['src', 'r1i', 'r2i', 'r3i', 'r4i', 'downsample_ratio'],
+ input_names=['src', 'r1i', 'r2i', 'r3i', 'r4i'],
+ output_names=['fgr', 'pha', 'r1o', 'r2o', 'r3o', 'r4o'],
+ dynamic_axes={
+ 'src': {0: 'batch_size0', 2: 'height0', 3: 'width0'},
+ 'fgr': {0: 'batch_size1', 2: 'height1', 3: 'width1'},
+ 'pha': {0: 'batch_size2', 2: 'height2', 3: 'width2'},
+ 'r1i': {0: 'batch_size3', 1: 'channels3', 2: 'height3', 3: 'width3'},
+ 'r2i': {0: 'batch_size4', 1: 'channels4', 2: 'height4', 3: 'width4'},
+ 'r3i': {0: 'batch_size5', 1: 'channels5', 2: 'height5', 3: 'width5'},
+ 'r4i': {0: 'batch_size6', 1: 'channels6', 2: 'height6', 3: 'width6'},
+ 'r1o': {0: 'batch_size7', 2: 'height7', 3: 'width7'},
+ 'r2o': {0: 'batch_size8', 2: 'height8', 3: 'width8'},
+ 'r3o': {0: 'batch_size9', 2: 'height9', 3: 'width9'},
+ 'r4o': {0: 'batch_size10', 2: 'height10', 3: 'width10'},
+ })
+```
+
+Run the following command
+
+```shell
+python export_onnx.py \
+ --model-variant mobilenetv3 \
+ --checkpoint rvm_mobilenetv3.pth \
+ --precision float32 \
+ --opset 12 \
+ --device cuda \
+ --output rvm_mobilenetv3.onnx
+```
+
+**Note**:
+- For TRT dynamic shapes with a multi-input ONNX model, if x0 and x1 have different shapes they cannot both be labeled height/width; use height0, height1, etc. to tell them apart, otherwise the engine build stage will fail (a quick check is sketched below)
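+
+As a quick sanity check (a sketch, assuming the `onnx` Python package is installed), you can confirm that every dynamic axis of the exported model carries a distinct name:
+
+```python
+import onnx
+
+model = onnx.load("rvm_mobilenetv3.onnx")
+dim_names = []
+for inp in model.graph.input:
+    for dim in inp.type.tensor_type.shape.dim:
+        if dim.dim_param:  # dynamic axis
+            dim_names.append(dim.dim_param)
+# every dynamic input axis should have a unique name (height0, height1, ...)
+assert len(dim_names) == len(set(dim_names)), "duplicate dynamic axis names"
+print(dim_names)
+```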
+
+## Step 4: Simplify with onnxsim
+
+Install onnxsim and simplify the ONNX model exported in Step 3
+
+```shell
+pip install onnxsim
+onnxsim rvm_mobilenetv3.onnx rvm_mobilenetv3_trt.onnx
+```
+
+```rvm_mobilenetv3_trt.onnx``` is the dynamic-shape ONNX model that can run on the TRT backend
diff --git a/examples/vision/matting/rvm/python/README.md b/examples/vision/matting/rvm/python/README.md
new file mode 100755
index 000000000..5b3676c08
--- /dev/null
+++ b/examples/vision/matting/rvm/python/README.md
@@ -0,0 +1,88 @@
+# RobustVideoMatting Python Deployment Example
+
+Before deployment, confirm the following two steps
+
+- 1. The hardware and software environment meets the requirements, see [FastDeploy Environment Requirements](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md)
+- 2. The FastDeploy Python wheel is installed, see [FastDeploy Python Installation](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md)
+
+This directory provides `infer.py` to quickly deploy RobustVideoMatting on CPU/GPU, and on GPU with TensorRT acceleration. Run the following script to complete it
+
+```bash
+# Download the deployment example code
+git clone https://github.com/PaddlePaddle/FastDeploy.git
+cd FastDeploy/examples/vision/matting/rvm/python
+
+# Download the RobustVideoMatting model files, test image, and video
+## Original ONNX model
+wget https://bj.bcebos.com/paddlehub/fastdeploy/rvm_mobilenetv3_fp32.onnx
+## ONNX model specially processed for TensorRT
+wget https://bj.bcebos.com/paddlehub/fastdeploy/rvm_mobilenetv3_trt.onnx
+wget https://bj.bcebos.com/paddlehub/fastdeploy/matting_input.jpg
+wget https://bj.bcebos.com/paddlehub/fastdeploy/matting_bgr.jpg
+wget https://bj.bcebos.com/paddlehub/fastdeploy/video.mp4
+
+# CPU inference
+## Image
+python infer.py --model rvm_mobilenetv3_fp32.onnx --image matting_input.jpg --bg matting_bgr.jpg --device cpu
+## Video
+python infer.py --model rvm_mobilenetv3_fp32.onnx --video video.mp4 --bg matting_bgr.jpg --device cpu
+# GPU inference
+## Image
+python infer.py --model rvm_mobilenetv3_fp32.onnx --image matting_input.jpg --bg matting_bgr.jpg --device gpu
+## Video
+python infer.py --model rvm_mobilenetv3_fp32.onnx --video video.mp4 --bg matting_bgr.jpg --device gpu
+# TensorRT inference
+## Image
+python infer.py --model rvm_mobilenetv3_trt.onnx --image matting_input.jpg --bg matting_bgr.jpg --device gpu --use_trt True
+## Video
+python infer.py --model rvm_mobilenetv3_trt.onnx --video video.mp4 --bg matting_bgr.jpg --device gpu --use_trt True
+```
+
+After running, the visualized results are as shown below
+
+
+## RobustVideoMatting Python API
+
+```python
+fd.vision.matting.RobustVideoMatting(model_file, params_file=None, runtime_option=None, model_format=ModelFormat.ONNX)
+```
+
+RobustVideoMatting model loading and initialization, where model_file is the exported ONNX model
+
+**Parameters**
+
+> * **model_file**(str): Path of the model file
+> * **params_file**(str): Path of the parameter file; when the model format is ONNX, this parameter does not need to be set
+> * **runtime_option**(RuntimeOption): Backend inference configuration; None by default, which means the default configuration is used
+> * **model_format**(ModelFormat): Model format, ONNX by default
+
+### predict function
+
+> ```python
+> RobustVideoMatting.predict(input_image)
+> ```
+>
+> Model prediction interface; takes an image and returns the matting result directly.
+>
+> **Parameters**
+>
+> > * **input_image**(np.ndarray): Input data; note that it must be an HWC image in BGR format
+
+> **Return**
+>
+> > Returns a `fastdeploy.vision.MattingResult` structure; see [Vision Model Prediction Results](../../../../../docs/api/vision_results/) for its description
+
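+Putting the API above together, a minimal usage sketch (based on `infer.py`; the file names are the test assets downloaded above):
+
+```python
+import cv2
+import fastdeploy as fd
+
+model = fd.vision.matting.RobustVideoMatting("rvm_mobilenetv3_fp32.onnx")
+model.video_mode = False  # unrelated still images: do not carry recurrent state across calls
+
+im = cv2.imread("matting_input.jpg")
+bg = cv2.imread("matting_bgr.jpg")
+result = model.predict(im)
+cv2.imwrite("alpha_vis.jpg", fd.vision.vis_matting(im, result))
+cv2.imwrite("replaced_bg.jpg", fd.vision.swap_background_matting(im, bg, result))
+```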
+
+## Other Documents
+
+- [RobustVideoMatting Model Description](..)
+- [RobustVideoMatting C++ Deployment](../cpp)
+- [Model Prediction Result Description](../../../../../docs/api/vision_results/)
+- [How to switch the model inference backend](../../../../../docs/cn/faq/how_to_change_backend.md)
diff --git a/examples/vision/matting/rvm/python/infer.py b/examples/vision/matting/rvm/python/infer.py
new file mode 100755
index 000000000..11951b00f
--- /dev/null
+++ b/examples/vision/matting/rvm/python/infer.py
@@ -0,0 +1,112 @@
+import fastdeploy as fd
+import cv2
+import os
+
+
+def parse_arguments():
+ import argparse
+ import ast
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--model", required=True, help="Path of RobustVideoMatting model.")
+ parser.add_argument("--image", type=str, help="Path of test image file.")
+ parser.add_argument("--video", type=str, help="Path of test video file.")
+ parser.add_argument(
+ "--bg",
+ type=str,
+ required=True,
+ default=None,
+ help="Path of test background image file.")
+ parser.add_argument(
+ '--output-composition',
+ type=str,
+ default="composition.mp4",
+ help="Path of composition video file.")
+ parser.add_argument(
+ '--output-alpha',
+ type=str,
+ default="alpha.mp4",
+ help="Path of alpha video file.")
+ parser.add_argument(
+ "--device",
+ type=str,
+ default='cpu',
+ help="Type of inference device, support 'cpu' or 'gpu'.")
+ parser.add_argument(
+ "--use_trt",
+ type=ast.literal_eval,
+ default=False,
+ help="Wether to use tensorrt.")
+ return parser.parse_args()
+
+
+def build_option(args):
+ option = fd.RuntimeOption()
+ if args.device.lower() == "gpu":
+ option.use_gpu()
+ if args.use_trt:
+ option.use_trt_backend()
+ option.set_trt_input_shape("src", [1, 3, 1920, 1080])
+ option.set_trt_input_shape("r1i", [1, 1, 1, 1], [1, 16, 240, 135],
+ [1, 16, 240, 135])
+ option.set_trt_input_shape("r2i", [1, 1, 1, 1], [1, 20, 120, 68],
+ [1, 20, 120, 68])
+ option.set_trt_input_shape("r3i", [1, 1, 1, 1], [1, 40, 60, 34],
+ [1, 40, 60, 34])
+ option.set_trt_input_shape("r4i", [1, 1, 1, 1], [1, 64, 30, 17],
+ [1, 64, 30, 17])
+
+ return option
+
+
+args = parse_arguments()
+output_composition = args.output_composition
+output_alpha = args.output_alpha
+
+# Configure the runtime and load the model
+runtime_option = build_option(args)
+model = fd.vision.matting.RobustVideoMatting(
+ args.model, runtime_option=runtime_option)
+bg = cv2.imread(args.bg)
+
+if args.video is not None:
+ # for video
+ cap = cv2.VideoCapture(args.video)
+ # Define the codec and create VideoWriter object
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
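+    # NOTE: the writers assume 1080x1920 (width x height) frames, which matches the demo video.mp4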
+ composition = cv2.VideoWriter(output_composition, fourcc, 20.0,
+ (1080, 1920))
+ alpha = cv2.VideoWriter(output_alpha, fourcc, 20.0, (1080, 1920))
+
+ frame_id = 0
+ while True:
+ frame_id = frame_id + 1
+ _, frame = cap.read()
+ if frame is None:
+ break
+ result = model.predict(frame)
+ vis_im = fd.vision.vis_matting(frame, result)
+ vis_im_with_bg = fd.vision.swap_background_matting(frame, bg, result)
+ alpha.write(vis_im)
+ composition.write(vis_im_with_bg)
+ cv2.waitKey(30)
+ cap.release()
+ composition.release()
+ alpha.release()
+ cv2.destroyAllWindows()
+ print("Visualized result video save in {} and {}".format(
+ output_composition, output_alpha))
+
+if args.image is not None:
+ # for image
+ im = cv2.imread(args.image)
+ result = model.predict(im.copy())
+ print(result)
+    # Visualize the results
+ vis_im = fd.vision.vis_matting(im, result)
+ vis_im_with_bg = fd.vision.swap_background_matting(im, bg, result)
+ cv2.imwrite("visualized_result_fg.jpg", vis_im)
+ cv2.imwrite("visualized_result_replaced_bg.jpg", vis_im_with_bg)
+ print(
+ "Visualized result save in ./visualized_result_replaced_bg.jpg and ./visualized_result_fg.jpg"
+ )
diff --git a/fastdeploy/backends/poros/common/iengine.h b/fastdeploy/backends/poros/common/iengine.h
index 5cb49e1ee..c945621c1 100755
--- a/fastdeploy/backends/poros/common/iengine.h
+++ b/fastdeploy/backends/poros/common/iengine.h
@@ -27,11 +27,6 @@ namespace baidu {
namespace mirana {
namespace poros {
-/**
- * the base engine class
- * every registered engine should inherit from this IEngine
- **/
-
struct PorosGraph {
torch::jit::Graph* graph = NULL;
torch::jit::Node* node = NULL;
diff --git a/fastdeploy/backends/tensorrt/trt_backend.cc b/fastdeploy/backends/tensorrt/trt_backend.cc
old mode 100644
new mode 100755
index 363a9d1ce..100ce6f7d
--- a/fastdeploy/backends/tensorrt/trt_backend.cc
+++ b/fastdeploy/backends/tensorrt/trt_backend.cc
@@ -525,7 +525,7 @@ bool TrtBackend::BuildTrtEngine() {
engine_->createExecutionContext());
GetInputOutputInfo();
- FDINFO << "TensorRT Engine is built succussfully." << std::endl;
+ FDINFO << "TensorRT Engine is built successfully." << std::endl;
if (option_.serialize_file != "") {
FDINFO << "Serialize TensorRTEngine to local file "
<< option_.serialize_file << "." << std::endl;
diff --git a/fastdeploy/vision.h b/fastdeploy/vision.h
index 10d69c458..f69129b76 100755
--- a/fastdeploy/vision.h
+++ b/fastdeploy/vision.h
@@ -41,6 +41,7 @@
#include "fastdeploy/vision/faceid/contrib/vpl.h"
#include "fastdeploy/vision/keypointdet/pptinypose/pptinypose.h"
#include "fastdeploy/vision/matting/contrib/modnet.h"
+#include "fastdeploy/vision/matting/contrib/rvm.h"
#include "fastdeploy/vision/matting/ppmatting/ppmatting.h"
#include "fastdeploy/vision/ocr/ppocr/classifier.h"
#include "fastdeploy/vision/ocr/ppocr/dbdetector.h"
diff --git a/fastdeploy/vision/matting/contrib/modnet.cc b/fastdeploy/vision/matting/contrib/modnet.cc
old mode 100644
new mode 100755
index 06c1ab52f..b08266547
--- a/fastdeploy/vision/matting/contrib/modnet.cc
+++ b/fastdeploy/vision/matting/contrib/modnet.cc
@@ -86,7 +86,7 @@ bool MODNet::Postprocess(
FDASSERT((infer_result.size() == 1),
"The default number of output tensor must be 1 according to "
"modnet.");
- FDTensor& alpha_tensor = infer_result.at(0); // (1,h,w,1)
+ FDTensor& alpha_tensor = infer_result.at(0); // (1, 1, h, w)
FDASSERT((alpha_tensor.shape[0] == 1), "Only support batch =1 now.");
if (alpha_tensor.dtype != FDDataType::FP32) {
FDERROR << "Only support post process with float32 data." << std::endl;
diff --git a/fastdeploy/vision/matting/contrib/rvm.cc b/fastdeploy/vision/matting/contrib/rvm.cc
new file mode 100755
index 000000000..04b9b9316
--- /dev/null
+++ b/fastdeploy/vision/matting/contrib/rvm.cc
@@ -0,0 +1,182 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fastdeploy/vision/matting/contrib/rvm.h"
+#include "fastdeploy/utils/perf.h"
+#include "fastdeploy/vision/utils/utils.h"
+
+namespace fastdeploy {
+
+namespace vision {
+
+namespace matting {
+
+RobustVideoMatting::RobustVideoMatting(const std::string& model_file, const std::string& params_file,
+ const RuntimeOption& custom_option,
+ const ModelFormat& model_format) {
+ if (model_format == ModelFormat::ONNX) {
+ valid_cpu_backends = {Backend::OPENVINO, Backend::ORT};
+ valid_gpu_backends = {Backend::ORT, Backend::TRT};
+ } else {
+ valid_cpu_backends = {Backend::PDINFER, Backend::ORT};
+ valid_gpu_backends = {Backend::PDINFER, Backend::ORT, Backend::TRT};
+ }
+ runtime_option = custom_option;
+ runtime_option.model_format = model_format;
+ runtime_option.model_file = model_file;
+ runtime_option.params_file = params_file;
+ initialized = Initialize();
+}
+
+bool RobustVideoMatting::Initialize() {
+ // parameters for preprocess
+ size = {1080, 1920};
+
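+  // Keep the recurrent state across frames by default; set video_mode to false
+  // when the inputs are unrelated still images.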
+ video_mode = true;
+
+ if (!InitRuntime()) {
+ FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
+ return false;
+ }
+ return true;
+}
+
+bool RobustVideoMatting::Preprocess(Mat* mat, FDTensor* output,
+                           std::map<std::string, std::vector<int64_t>>* im_info) {
+ // Resize
+ int resize_w = size[0];
+ int resize_h = size[1];
+ if (resize_h != mat->Height() || resize_w != mat->Width()) {
+ Resize::Run(mat, resize_w, resize_h);
+ }
+ BGR2RGB::Run(mat);
+
+ // Normalize
+  std::vector<float> alpha = {1.0f / 255.0f, 1.0f / 255.0f, 1.0f / 255.0f};
+  std::vector<float> beta = {0.0f, 0.0f, 0.0f};
+ Convert::Run(mat, alpha, beta);
+ // Record output shape of preprocessed image
+ (*im_info)["output_shape"] = {mat->Height(), mat->Width()};
+
+ HWC2CHW::Run(mat);
+ Cast::Run(mat, "float");
+
+ mat->ShareWithTensor(output);
+  output->shape.insert(output->shape.begin(), 1);  // reshape to 1, c, h, w
+ return true;
+}
+
+bool RobustVideoMatting::Postprocess(
+    std::vector<FDTensor>& infer_result, MattingResult* result,
+    const std::map<std::string, std::vector<int64_t>>& im_info) {
+ FDASSERT((infer_result.size() == 6),
+ "The default number of output tensor must be 6 according to "
+ "RobustVideoMatting.");
+ FDTensor& fgr = infer_result.at(0); // fgr (1, 3, h, w) 0.~1.
+ FDTensor& alpha = infer_result.at(1); // alpha (1, 1, h, w) 0.~1.
+ FDASSERT((fgr.shape[0] == 1), "Only support batch = 1 now.");
+ FDASSERT((alpha.shape[0] == 1), "Only support batch = 1 now.");
+ if (fgr.dtype != FDDataType::FP32) {
+ FDERROR << "Only support post process with float32 data." << std::endl;
+ return false;
+ }
+ if (alpha.dtype != FDDataType::FP32) {
+ FDERROR << "Only support post process with float32 data." << std::endl;
+ return false;
+ }
+  // update context: save the recurrent outputs (r1o-r4o) so they are fed back as r1i-r4i on the next frame
+ if (video_mode) {
+ for (size_t i = 0; i < 4; ++i) {
+ FDTensor& rki = infer_result.at(i+2);
+ dynamic_inputs_dims_[i] = rki.shape;
+ dynamic_inputs_datas_[i].resize(rki.Numel());
+ memcpy(dynamic_inputs_datas_[i].data(), rki.Data(),
+ rki.Numel() * FDDataTypeSize(rki.dtype));
+ }
+ }
+
+ auto iter_in = im_info.find("input_shape");
+ auto iter_out = im_info.find("output_shape");
+ FDASSERT(iter_out != im_info.end() && iter_in != im_info.end(),
+ "Cannot find input_shape or output_shape from im_info.");
+ int out_h = iter_out->second[0];
+ int out_w = iter_out->second[1];
+ int in_h = iter_in->second[0];
+ int in_w = iter_in->second[1];
+
+ // for alpha
+  float* alpha_ptr = static_cast<float*>(alpha.Data());
+ cv::Mat alpha_zero_copy_ref(out_h, out_w, CV_32FC1, alpha_ptr);
+ Mat alpha_resized(alpha_zero_copy_ref); // ref-only, zero copy.
+ if ((out_h != in_h) || (out_w != in_w)) {
+ // already allocated a new continuous memory after resize.
+ Resize::Run(&alpha_resized, in_w, in_h, -1, -1);
+ }
+
+ // for foreground
+  float* fgr_ptr = static_cast<float*>(fgr.Data());
+ cv::Mat fgr_zero_copy_ref(out_h, out_w, CV_32FC1, fgr_ptr);
+ Mat fgr_resized(fgr_zero_copy_ref); // ref-only, zero copy.
+ if ((out_h != in_h) || (out_w != in_w)) {
+ // already allocated a new continuous memory after resize.
+ Resize::Run(&fgr_resized, in_w, in_h, -1, -1);
+ }
+
+ result->Clear();
+ result->contain_foreground = true;
+  result->shape = {static_cast<int64_t>(in_h), static_cast<int64_t>(in_w)};
+ int numel = in_h * in_w;
+ int nbytes = numel * sizeof(float);
+ result->Resize(numel);
+ memcpy(result->alpha.data(), alpha_resized.GetOpenCVMat()->data, nbytes);
+ memcpy(result->foreground.data(), fgr_resized.GetOpenCVMat()->data, nbytes);
+ return true;
+}
+
+bool RobustVideoMatting::Predict(cv::Mat* im, MattingResult* result) {
+ Mat mat(*im);
+ int inputs_nums = NumInputsOfRuntime();
+  std::vector<FDTensor> input_tensors(inputs_nums);
+  std::map<std::string, std::vector<int64_t>> im_info;
+ // Record the shape of image and the shape of preprocessed image
+ im_info["input_shape"] = {mat.Height(), mat.Width()};
+ im_info["output_shape"] = {mat.Height(), mat.Width()};
+  // fill the recurrent state inputs (r1i-r4i) with the values kept from the previous frame
+ for (size_t i = 1; i < inputs_nums; ++i) {
+ input_tensors[i].SetExternalData(dynamic_inputs_dims_[i-1], FDDataType::FP32, dynamic_inputs_datas_[i-1].data());
+ input_tensors[i].device = Device::CPU;
+ }
+ if (!Preprocess(&mat, &input_tensors[0], &im_info)) {
+ FDERROR << "Failed to preprocess input image." << std::endl;
+ return false;
+ }
+ for (size_t i = 0; i < inputs_nums; ++i) {
+ input_tensors[i].name = InputInfoOfRuntime(i).name;
+ }
+  std::vector<FDTensor> output_tensors;
+ if (!Infer(input_tensors, &output_tensors)) {
+ FDERROR << "Failed to inference." << std::endl;
+ return false;
+ }
+
+ if (!Postprocess(output_tensors, result, im_info)) {
+ FDERROR << "Failed to post process." << std::endl;
+ return false;
+ }
+ return true;
+}
+
+} // namespace matting
+} // namespace vision
+} // namespace fastdeploy
\ No newline at end of file
diff --git a/fastdeploy/vision/matting/contrib/rvm.h b/fastdeploy/vision/matting/contrib/rvm.h
new file mode 100755
index 000000000..58c64ac3b
--- /dev/null
+++ b/fastdeploy/vision/matting/contrib/rvm.h
@@ -0,0 +1,92 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "fastdeploy/fastdeploy_model.h"
+#include "fastdeploy/vision/common/processors/transform.h"
+#include "fastdeploy/vision/common/result.h"
+
+namespace fastdeploy {
+
+namespace vision {
+/** \brief All image/video matting model APIs are defined inside this namespace
+ *
+ */
+namespace matting {
+
+/*! @brief RobustVideoMatting model object, used to load a matting model exported by the RobustVideoMatting project
+ */
+class FASTDEPLOY_DECL RobustVideoMatting : public FastDeployModel {
+ public:
+ /** \brief Set path of model file and configuration file, and the configuration of runtime
+ *
+ * \param[in] model_file Path of model file, e.g rvm/rvm_mobilenetv3_fp32.onnx
+ * \param[in] params_file Path of parameter file, if the model format is ONNX, this parameter will be ignored
+ * \param[in] custom_option RuntimeOption for inference, the default will use cpu, and choose the backend defined in `valid_cpu_backends`
+ * \param[in] model_format Model format of the loaded model, default is ONNX format
+ */
+ RobustVideoMatting(const std::string& model_file,
+ const std::string& params_file = "",
+ const RuntimeOption& custom_option = RuntimeOption(),
+ const ModelFormat& model_format = ModelFormat::ONNX);
+
+ /// Get model's name
+ std::string ModelName() const { return "matting/RobustVideoMatting"; }
+
+ /** \brief Predict the matting result for an input image
+ *
+ * \param[in] im The input image data, comes from cv::imread()
+   * \param[in] result The output matting result will be written to this structure
+   * \return true if the prediction succeeded, otherwise false
+ */
+ bool Predict(cv::Mat* im, MattingResult* result);
+
+ /// Preprocess image size, the default is (1080, 1920)
+  std::vector<int> size;
+
+  /// Whether to enable video mode; if the inputs are unrelated pictures, set it to false. The default is true // NOLINT
+ bool video_mode;
+
+ private:
+ bool Initialize();
+ /// Preprocess an input image, and set the preprocessed results to `outputs`
+  bool Preprocess(Mat* mat, FDTensor* output,
+                  std::map<std::string, std::vector<int64_t>>* im_info);
+
+ /// Postprocess the inferenced results, and set the final result to `result`
+  bool Postprocess(std::vector<FDTensor>& infer_result, MattingResult* result,
+                   const std::map<std::string, std::vector<int64_t>>& im_info);
+
+  /// Initial data for the dynamic (recurrent) inputs
+  std::vector<std::vector<float>> dynamic_inputs_datas_ = {
+ {0.0f}, // r1i
+ {0.0f}, // r2i
+ {0.0f}, // r3i
+ {0.0f}, // r4i
+ {0.25f}, // downsample_ratio
+ };
+
+  /// Initial dims for the dynamic (recurrent) inputs
+  std::vector<std::vector<int64_t>> dynamic_inputs_dims_ = {
+ {1, 1, 1, 1}, // r1i
+ {1, 1, 1, 1}, // r2i
+ {1, 1, 1, 1}, // r3i
+ {1, 1, 1, 1}, // r4i
+ {1}, // downsample_ratio
+ };
+};
+
+} // namespace matting
+} // namespace vision
+} // namespace fastdeploy
diff --git a/fastdeploy/vision/matting/contrib/rvm_pybind.cc b/fastdeploy/vision/matting/contrib/rvm_pybind.cc
new file mode 100755
index 000000000..a45816d65
--- /dev/null
+++ b/fastdeploy/vision/matting/contrib/rvm_pybind.cc
@@ -0,0 +1,34 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fastdeploy/pybind/main.h"
+
+namespace fastdeploy {
+void BindRobustVideoMatting(pybind11::module& m) {
+ // Bind RobustVideoMatting
+  pybind11::class_<vision::matting::RobustVideoMatting, FastDeployModel>(
+      m, "RobustVideoMatting")
+      .def(pybind11::init<std::string, std::string, RuntimeOption,
+                          ModelFormat>())
+ .def("predict",
+ [](vision::matting::RobustVideoMatting& self, pybind11::array& data) {
+ auto mat = PyArrayToCvMat(data);
+ vision::MattingResult res;
+ self.Predict(&mat, &res);
+ return res;
+ })
+ .def_readwrite("size", &vision::matting::RobustVideoMatting::size)
+ .def_readwrite("video_mode", &vision::matting::RobustVideoMatting::video_mode);
+}
+
+} // namespace fastdeploy
diff --git a/fastdeploy/vision/matting/matting_pybind.cc b/fastdeploy/vision/matting/matting_pybind.cc
index 8c514a428..204dd7192 100644
--- a/fastdeploy/vision/matting/matting_pybind.cc
+++ b/fastdeploy/vision/matting/matting_pybind.cc
@@ -17,12 +17,14 @@
namespace fastdeploy {
void BindMODNet(pybind11::module& m);
+void BindRobustVideoMatting(pybind11::module& m);
void BindPPMatting(pybind11::module& m);
void BindMatting(pybind11::module& m) {
auto matting_module =
- m.def_submodule("matting", "Image object matting models.");
+ m.def_submodule("matting", "Image/Video matting models.");
BindMODNet(matting_module);
+ BindRobustVideoMatting(matting_module);
BindPPMatting(matting_module);
}
} // namespace fastdeploy
diff --git a/fastdeploy/vision/matting/ppmatting/ppmatting.cc b/fastdeploy/vision/matting/ppmatting/ppmatting.cc
old mode 100644
new mode 100755
index 9c342d315..e760ab523
--- a/fastdeploy/vision/matting/ppmatting/ppmatting.cc
+++ b/fastdeploy/vision/matting/ppmatting/ppmatting.cc
@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/matting/ppmatting/ppmatting.h"
-#include "fastdeploy/vision.h"
#include "fastdeploy/vision/utils/utils.h"
#include "yaml-cpp/yaml.h"
@@ -163,8 +162,8 @@ bool PPMatting::Postprocess(
const std::map>& im_info) {
FDASSERT((infer_result.size() == 1),
"The default number of output tensor must be 1 ");
- FDTensor& alpha_tensor = infer_result.at(0); // (1,h,w,1)
- FDASSERT((alpha_tensor.shape[0] == 1), "Only support batch =1 now.");
+ FDTensor& alpha_tensor = infer_result.at(0); // (1, 1, h, w)
+ FDASSERT((alpha_tensor.shape[0] == 1), "Only support batch = 1 now.");
if (alpha_tensor.dtype != FDDataType::FP32) {
FDERROR << "Only support post process with float32 data." << std::endl;
return false;
diff --git a/fastdeploy/vision/segmentation/ppseg/model.cc b/fastdeploy/vision/segmentation/ppseg/model.cc
old mode 100644
new mode 100755
index cd28836fb..21c377485
--- a/fastdeploy/vision/segmentation/ppseg/model.cc
+++ b/fastdeploy/vision/segmentation/ppseg/model.cc
@@ -13,7 +13,6 @@
// limitations under the License.
#include "fastdeploy/vision/segmentation/ppseg/model.h"
-#include "fastdeploy/vision.h"
#include "fastdeploy/vision/utils/utils.h"
#include "yaml-cpp/yaml.h"
diff --git a/python/fastdeploy/vision/matting/__init__.py b/python/fastdeploy/vision/matting/__init__.py
index 9a03d81e0..43b0acbcf 100644
--- a/python/fastdeploy/vision/matting/__init__.py
+++ b/python/fastdeploy/vision/matting/__init__.py
@@ -14,4 +14,5 @@
from __future__ import absolute_import
from .contrib.modnet import MODNet
+from .contrib.rvm import RobustVideoMatting
from .ppmatting import PPMatting
diff --git a/python/fastdeploy/vision/matting/contrib/rvm.py b/python/fastdeploy/vision/matting/contrib/rvm.py
new file mode 100755
index 000000000..144a3823c
--- /dev/null
+++ b/python/fastdeploy/vision/matting/contrib/rvm.py
@@ -0,0 +1,81 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+import logging
+from .... import FastDeployModel, ModelFormat
+from .... import c_lib_wrap as C
+
+
+class RobustVideoMatting(FastDeployModel):
+ def __init__(self,
+ model_file,
+ params_file="",
+ runtime_option=None,
+ model_format=ModelFormat.ONNX):
+ """Load a video matting model exported by RobustVideoMatting.
+
+ :param model_file: (str)Path of model file, e.g rvm/rvm_mobilenetv3_fp32.onnx
+        :param params_file: (str)Path of parameters file, if the model_format is ModelFormat.ONNX, this param will be ignored, can be set as empty string
+        :param runtime_option: (fastdeploy.RuntimeOption)RuntimeOption for inference of this model, if it's None, will use the default backend on CPU
+        :param model_format: (fastdeploy.ModelFormat)Model format of the loaded model, default is ONNX
+ """
+ super(RobustVideoMatting, self).__init__(runtime_option)
+
+ self._model = C.vision.matting.RobustVideoMatting(
+ model_file, params_file, self._runtime_option, model_format)
+ assert self.initialized, "RobustVideoMatting initialize failed."
+
+ def predict(self, input_image):
+ """Matting an input image
+
+        :param input_image: (numpy.ndarray)The input image data, 3-D array with layout HWC, BGR format
+ :return: MattingResult
+ """
+ return self._model.predict(input_image)
+
+ @property
+ def size(self):
+ """
+ Returns the preprocess image size
+ """
+ return self._model.size
+
+ @property
+ def video_mode(self):
+ """
+        Whether to enable video mode; if the inputs are unrelated pictures, set it to False. The default is True
+ """
+ return self._model.video_mode
+
+ @size.setter
+ def size(self, wh):
+ """
+ Set the preprocess image size
+ """
+ assert isinstance(wh, (list, tuple)),\
+ "The value to set `size` must be type of tuple or list."
+ assert len(wh) == 2,\
+ "The value to set `size` must contatins 2 elements means [width, height], but now it contains {} elements.".format(
+ len(wh))
+ self._model.size = wh
+
+ @video_mode.setter
+ def video_mode(self, value):
+ """
+ Set video_mode property, the default is true
+ """
+ assert isinstance(
+ value, bool), "The value to set `video_mode` must be type of bool."
+ self._model.video_mode = value
diff --git a/tests/eval_example/test_rvm.py b/tests/eval_example/test_rvm.py
new file mode 100644
index 000000000..4b8d5afe8
--- /dev/null
+++ b/tests/eval_example/test_rvm.py
@@ -0,0 +1,101 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import fastdeploy as fd
+import cv2
+import os
+import pickle
+import numpy as np
+
+
+def test_matting_rvm_cpu():
+ model_url = "https://bj.bcebos.com/paddlehub/fastdeploy/rvm.tgz"
+ input_url = "https://bj.bcebos.com/paddlehub/fastdeploy/video.mp4"
+ fd.download_and_decompress(model_url, ".")
+ fd.download(input_url, ".")
+ model_path = "rvm/rvm_mobilenetv3_fp32.onnx"
+ # use ORT
+ runtime_option = fd.RuntimeOption()
+ runtime_option.use_ort_backend()
+ model = fd.vision.matting.RobustVideoMatting(
+ model_path, runtime_option=runtime_option)
+
+ cap = cv2.VideoCapture(input_url)
+
+ frame_id = 0
+ while True:
+ _, frame = cap.read()
+ if frame is None:
+ break
+ result = model.predict(frame)
+ # compare diff
+ expect_alpha = np.load("rvm/result_alpha_" + str(frame_id) + ".npy")
+ result_alpha = np.array(result.alpha).reshape(1920, 1080)
+ diff = np.fabs(expect_alpha - result_alpha)
+ thres = 1e-05
+ assert diff.max(
+ ) < thres, "The label diff is %f, which is bigger than %f" % (
+ diff.max(), thres)
+ frame_id = frame_id + 1
+ cv2.waitKey(30)
+ if frame_id >= 10:
+ cap.release()
+ cv2.destroyAllWindows()
+ break
+
+
+def test_matting_rvm_gpu_trt():
+ model_url = "https://bj.bcebos.com/paddlehub/fastdeploy/rvm.tgz"
+ input_url = "https://bj.bcebos.com/paddlehub/fastdeploy/video.mp4"
+ fd.download_and_decompress(model_url, ".")
+ fd.download(input_url, ".")
+ model_path = "rvm/rvm_mobilenetv3_trt.onnx"
+ # use TRT
+ runtime_option = fd.RuntimeOption()
+ runtime_option.use_gpu()
+ runtime_option.use_trt_backend()
+ runtime_option.set_trt_input_shape("src", [1, 3, 1920, 1080])
+ runtime_option.set_trt_input_shape("r1i", [1, 1, 1, 1], [1, 16, 240, 135],
+ [1, 16, 240, 135])
+ runtime_option.set_trt_input_shape("r2i", [1, 1, 1, 1], [1, 20, 120, 68],
+ [1, 20, 120, 68])
+ runtime_option.set_trt_input_shape("r3i", [1, 1, 1, 1], [1, 40, 60, 34],
+ [1, 40, 60, 34])
+ runtime_option.set_trt_input_shape("r4i", [1, 1, 1, 1], [1, 64, 30, 17],
+ [1, 64, 30, 17])
+ model = fd.vision.matting.RobustVideoMatting(
+ model_path, runtime_option=runtime_option)
+
+ cap = cv2.VideoCapture("./video.mp4")
+
+ frame_id = 0
+ while True:
+ _, frame = cap.read()
+ if frame is None:
+ break
+ result = model.predict(frame)
+ # compare diff
+ expect_alpha = np.load("rvm/result_alpha_" + str(frame_id) + ".npy")
+ result_alpha = np.array(result.alpha).reshape(1920, 1080)
+ diff = np.fabs(expect_alpha - result_alpha)
+ thres = 1e-04
+ assert diff.max(
+ ) < thres, "The label diff is %f, which is bigger than %f" % (
+ diff.max(), thres)
+ frame_id = frame_id + 1
+ cv2.waitKey(30)
+ if frame_id >= 10:
+ cap.release()
+ cv2.destroyAllWindows()
+ break