[Model] add RobustVideoMatting model (#400)

* add yolov5cls

* fixed bugs

* fixed bugs

* fixed preprocess bug

* add yolov5cls readme

* deal with comments

* Add YOLOv5Cls Note

* add yolov5cls test

* add rvm support

* support rvm model

* add rvm demo

* fixed bugs

* add rvm readme

* add TRT support

* add trt support

* add rvm test

* add EXPORT.md

* rename export.md

* rm poros doxyen

* deal with comments

* deal with comments

* add rvm video_mode note

Co-authored-by: Jason <jiangjiajun@baidu.com>
Co-authored-by: DefTruth <31974251+DefTruth@users.noreply.github.com>
WJJ1995
2022-10-26 14:30:04 +08:00
committed by GitHub
parent ba501fd963
commit 718698a32a
22 changed files with 1080 additions and 16 deletions

examples/vision/matting/README.md Normal file → Executable file

@@ -5,6 +5,7 @@ FastDeploy currently supports deployment of the following matting models
| Model | Description | Model Format | Version |
| :--- | :--- | :------- | :--- |
| [ZHKKKe/MODNet](./modnet) | MODNet series models | ONNX | [CommitID:28165a4](https://github.com/ZHKKKe/MODNet/commit/28165a4) |
| [PeterL1n/RobustVideoMatting](./rvm) | RobustVideoMatting series models | ONNX | [CommitID:81a1093](https://github.com/PeterL1n/RobustVideoMatting/commit/81a1093) |
| [PaddleSeg/PP-Matting](./ppmatting) | PP-Matting series models | Paddle | [Release/2.6](https://github.com/PaddlePaddle/PaddleSeg/tree/release/2.6/Matting) |
| [PaddleSeg/PP-HumanMatting](./ppmatting) | PP-HumanMatting series models | Paddle | [Release/2.6](https://github.com/PaddlePaddle/PaddleSeg/tree/release/2.6/Matting) |
| [PaddleSeg/ModNet](./ppmatting) | ModNet series models | Paddle | [Release/2.6](https://github.com/PaddlePaddle/PaddleSeg/tree/release/2.6/Matting) |

examples/vision/matting/modnet/python/README.md Normal file → Executable file

@@ -52,16 +52,14 @@ MODNet model loading and initialization, where model_file is the exported ONNX model
### predict function
> ```python
> MODNet.predict(image_data, conf_threshold=0.25, nms_iou_threshold=0.5)
> MODNet.predict(image_data)
> ```
>
> Model prediction interface; takes an image and returns the detection result.
> Model prediction interface; takes an image and returns the matting result.
>
> **Parameters**
>
> > * **image_data**(np.ndarray): input data, which must be in HWC, BGR format
> > * **conf_threshold**(float): confidence threshold for filtering detection boxes
> > * **nms_iou_threshold**(float): IoU threshold used during NMS
> **Return**
>


@@ -0,0 +1,30 @@
# RobustVideoMatting Model Deployment
## Model Version
- [RobustVideoMatting](https://github.com/PeterL1n/RobustVideoMatting/commit/81a1093)
## Supported Model List
FastDeploy currently supports deployment of the following models:
- [RobustVideoMatting models](https://github.com/PeterL1n/RobustVideoMatting)
## Download Pre-trained Models
For developers' convenience, the exported RobustVideoMatting models are provided below and can be downloaded and used directly.
| Model | Parameter Size | Accuracy | Note |
|:---------------------------------------------------------------- |:----- |:----- | :------ |
| [rvm_mobilenetv3_fp32.onnx](https://bj.bcebos.com/paddlehub/fastdeploy/rvm_mobilenetv3_fp32.onnx) | 15MB | - | - |
| [rvm_resnet50_fp32.onnx](https://bj.bcebos.com/paddlehub/fastdeploy/rvm_resnet50_fp32.onnx) | 103MB | - | - |
| [rvm_mobilenetv3_trt.onnx](https://bj.bcebos.com/paddlehub/fastdeploy/rvm_mobilenetv3_trt.onnx) | 15MB | - | - |
| [rvm_resnet50_trt.onnx](https://bj.bcebos.com/paddlehub/fastdeploy/rvm_resnet50_trt.onnx) | 103MB | - | - |
**Note**
- To run inference with TensorRT, download the ONNX model files whose names end with the trt suffix.
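For reference, here is a minimal sketch of pairing each download with a backend. It mirrors the options used in `infer.py` in this example; the pinned shapes below come from that script and assume the default 1080x1920 preprocessing size.

```python
import fastdeploy as fd

# CPU/GPU with the default backends: use the plain *_fp32.onnx export
model = fd.vision.matting.RobustVideoMatting("rvm_mobilenetv3_fp32.onnx")

# TensorRT: use the *_trt.onnx export and pin the dynamic input shapes,
# exactly as done in the Python demo of this example
trt_option = fd.RuntimeOption()
trt_option.use_gpu()
trt_option.use_trt_backend()
trt_option.set_trt_input_shape("src", [1, 3, 1920, 1080])
trt_option.set_trt_input_shape("r1i", [1, 1, 1, 1], [1, 16, 240, 135], [1, 16, 240, 135])
trt_option.set_trt_input_shape("r2i", [1, 1, 1, 1], [1, 20, 120, 68], [1, 20, 120, 68])
trt_option.set_trt_input_shape("r3i", [1, 1, 1, 1], [1, 40, 60, 34], [1, 40, 60, 34])
trt_option.set_trt_input_shape("r4i", [1, 1, 1, 1], [1, 64, 30, 17], [1, 64, 30, 17])
trt_model = fd.vision.matting.RobustVideoMatting(
    "rvm_mobilenetv3_trt.onnx", runtime_option=trt_option)
```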
## Detailed Deployment Documentation
- [Python deployment](python)
- [C++ deployment](cpp)


@@ -0,0 +1,14 @@
PROJECT(infer_demo C CXX)
CMAKE_MINIMUM_REQUIRED (VERSION 3.10)
# Specify the path of the downloaded and extracted FastDeploy SDK
option(FASTDEPLOY_INSTALL_DIR "Path of downloaded fastdeploy sdk.")
include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake)
# Add FastDeploy dependency header files
include_directories(${FASTDEPLOY_INCS})
add_executable(infer_demo ${PROJECT_SOURCE_DIR}/infer.cc)
# Link against the FastDeploy libraries
target_link_libraries(infer_demo ${FASTDEPLOY_LIBS})


@@ -0,0 +1,87 @@
# RobustVideoMatting C++ Deployment Example
Before deployment, confirm the following two steps
- 1. The hardware and software environment meets the requirements; refer to [FastDeploy Environment Requirements](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md)
- 2. Download the pre-built deployment library and samples code according to your development environment; refer to [FastDeploy Pre-built Libraries](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md)
Taking RobustVideoMatting inference on Linux as an example, run the following commands in this directory to complete the build and test. If you only need CPU deployment, the CPU inference library can be downloaded from [FastDeploy C++ Pre-built Libraries](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md/CPP_prebuilt_libraries.md).
This directory provides `infer.cc` to quickly deploy RobustVideoMatting on CPU/GPU, or on GPU with TensorRT acceleration. Run the following script to complete the deployment:
```bash
# Download the SDK and build the examples code (the examples code is included in the SDK)
wget https://bj.bcebos.com/fastdeploy/release/cpp/fastdeploy-linux-x64-gpu-0.3.0.tgz
tar xvf fastdeploy-linux-x64-gpu-0.3.0.tgz
cd fastdeploy-linux-x64-gpu-0.3.0/examples/vision/matting/rvm/cpp/
mkdir build && cd build
cmake .. -DFASTDEPLOY_INSTALL_DIR=${PWD}/../../../../../../../fastdeploy-linux-x64-gpu-0.3.0
make -j
# Download the RobustVideoMatting model files, test image, and video
## Original ONNX model
wget https://bj.bcebos.com/paddlehub/fastdeploy/rvm_mobilenetv3_fp32.onnx
## ONNX model specially processed for TensorRT loading
wget https://bj.bcebos.com/paddlehub/fastdeploy/rvm_mobilenetv3_trt.onnx
wget https://bj.bcebos.com/paddlehub/fastdeploy/matting_input.jpg
wget https://bj.bcebos.com/paddlehub/fastdeploy/matting_bgr.jpg
wget https://bj.bcebos.com/paddlehub/fastdeploy/video.mp4
# CPU inference
./infer_demo rvm_mobilenetv3_fp32.onnx matting_input.jpg matting_bgr.jpg 0
# GPU inference
./infer_demo rvm_mobilenetv3_fp32.onnx matting_input.jpg matting_bgr.jpg 1
# TensorRT inference
./infer_demo rvm_mobilenetv3_trt.onnx matting_input.jpg matting_bgr.jpg 2
```
The visualized results are shown below:
<div width="840">
<img width="200" height="200" float="left" src="https://user-images.githubusercontent.com/67993288/186852040-759da522-fca4-4786-9205-88c622cd4a39.jpg">
<img width="200" height="200" float="left" src="https://user-images.githubusercontent.com/67993288/186852587-48895efc-d24a-43c9-aeec-d7b0362ab2b9.jpg">
<img width="200" height="200" float="left" src="https://user-images.githubusercontent.com/67993288/186852116-cf91445b-3a67-45d9-a675-c69fe77c383a.jpg">
<img width="200" height="200" float="left" src="https://user-images.githubusercontent.com/67993288/186852554-6960659f-4fd7-4506-b33b-54e1a9dd89bf.jpg">
</div>
The commands above only work on Linux or macOS. For how to use the SDK on Windows, please refer to:
- [How to use the FastDeploy C++ SDK on Windows](../../../../../docs/cn/faq/use_sdk_on_windows.md)
## RobustVideoMatting C++ Interface
```c++
fastdeploy::vision::matting::RobustVideoMatting(
const string& model_file,
const string& params_file = "",
const RuntimeOption& runtime_option = RuntimeOption(),
const ModelFormat& model_format = ModelFormat::ONNX)
```
RobustVideoMatting model loading and initialization, where model_file is the exported ONNX model.
**Parameters**
> * **model_file**(str): path to the model file
> * **params_file**(str): path to the parameter file; when the model format is ONNX, this parameter does not need to be set
> * **runtime_option**(RuntimeOption): backend inference configuration; None uses the default configuration
> * **model_format**(ModelFormat): model format, ONNX by default
#### Predict Function
> ```c++
> RobustVideoMatting::Predict(cv::Mat* im, MattingResult* result)
> ```
>
> Model prediction interface; takes an image and returns the matting result.
>
> **Parameters**
>
> > * **im**: input image, which must be in HWC, BGR format
> > * **result**: the matting result; see [Vision Model Prediction Results](../../../../../docs/api/vision_results/) for a description of MattingResult
## Other Documents
- [Model introduction](../../)
- [Python deployment](../python)
- [Vision model prediction results](../../../../../docs/api/vision_results/)
- [How to switch the model inference backend](../../../../../docs/cn/faq/how_to_change_backend.md)


@@ -0,0 +1,131 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision.h"
#ifdef WIN32
const char sep = '\\';
#else
const char sep = '/';
#endif
void CpuInfer(const std::string& model_file, const std::string& image_file,
const std::string& background_file) {
auto option = fastdeploy::RuntimeOption();
auto model = fastdeploy::vision::matting::RobustVideoMatting(model_file, "", option);
if (!model.Initialized()) {
std::cerr << "Failed to initialize." << std::endl;
return;
}
auto im = cv::imread(image_file);
auto im_bak = im.clone();
cv::Mat bg = cv::imread(background_file);
fastdeploy::vision::MattingResult res;
if (!model.Predict(&im, &res)) {
std::cerr << "Failed to predict." << std::endl;
return;
}
auto vis_im = fastdeploy::vision::VisMatting(im_bak, res);
auto vis_im_with_bg =
fastdeploy::vision::SwapBackground(im_bak, bg, res);
cv::imwrite("visualized_result.jpg", vis_im_with_bg);
cv::imwrite("visualized_result_fg.jpg", vis_im);
std::cout << "Visualized result save in ./visualized_result.jpg "
"and ./visualized_result_fg.jpg"
<< std::endl;
}
void GpuInfer(const std::string& model_file, const std::string& image_file,
const std::string& background_file) {
auto option = fastdeploy::RuntimeOption();
option.UseGpu();
auto model = fastdeploy::vision::matting::RobustVideoMatting(model_file, "", option);
if (!model.Initialized()) {
std::cerr << "Failed to initialize." << std::endl;
return;
}
auto im = cv::imread(image_file);
auto im_bak = im.clone();
cv::Mat bg = cv::imread(background_file);
fastdeploy::vision::MattingResult res;
if (!model.Predict(&im, &res)) {
std::cerr << "Failed to predict." << std::endl;
return;
}
auto vis_im = fastdeploy::vision::VisMatting(im_bak, res);
auto vis_im_with_bg =
fastdeploy::vision::SwapBackground(im_bak, bg, res);
cv::imwrite("visualized_result.jpg", vis_im_with_bg);
cv::imwrite("visualized_result_fg.jpg", vis_im);
std::cout << "Visualized result save in ./visualized_result_replaced_bg.jpg "
"and ./visualized_result_fg.jpg"
<< std::endl;
}
void TrtInfer(const std::string& model_file, const std::string& image_file,
const std::string& background_file) {
auto option = fastdeploy::RuntimeOption();
option.UseGpu();
option.UseTrtBackend();
option.SetTrtInputShape("src", {1, 3, 1920, 1080});
option.SetTrtInputShape("r1i", {1, 1, 1, 1}, {1, 16, 240, 135}, {1, 16, 240, 135});
option.SetTrtInputShape("r2i", {1, 1, 1, 1}, {1, 20, 120, 68}, {1, 20, 120, 68});
option.SetTrtInputShape("r3i", {1, 1, 1, 1}, {1, 40, 60, 34}, {1, 40, 60, 34});
option.SetTrtInputShape("r4i", {1, 1, 1, 1}, {1, 64, 30, 17}, {1, 64, 30, 17});
auto model = fastdeploy::vision::matting::RobustVideoMatting(model_file, "", option);
if (!model.Initialized()) {
std::cerr << "Failed to initialize." << std::endl;
return;
}
auto im = cv::imread(image_file);
auto im_bak = im.clone();
cv::Mat bg = cv::imread(background_file);
fastdeploy::vision::MattingResult res;
if (!model.Predict(&im, &res)) {
std::cerr << "Failed to predict." << std::endl;
return;
}
auto vis_im = fastdeploy::vision::VisMatting(im_bak, res);
auto vis_im_with_bg =
fastdeploy::vision::SwapBackground(im_bak, bg, res);
cv::imwrite("visualized_result.jpg", vis_im_with_bg);
cv::imwrite("visualized_result_fg.jpg", vis_im);
std::cout << "Visualized result save in ./visualized_result.jpg "
"and ./visualized_result_fg.jpg"
<< std::endl;
}
int main(int argc, char* argv[]) {
if (argc < 5) {
std::cout
<< "Usage: infer_demo path/to/model_dir path/to/image run_option, "
"e.g ./infer_model ./rvm_mobilenetv3_fp32.onnx ./test.jpg ./test_bg.jpg 0"
<< std::endl;
std::cout << "The data type of run_option is int, 0: run with cpu; 1: run "
"with gpu; 2: run with gpu and use tensorrt backend."
<< std::endl;
return -1;
}
if (std::atoi(argv[4]) == 0) {
CpuInfer(argv[1], argv[2], argv[3]);
} else if (std::atoi(argv[4]) == 1) {
GpuInfer(argv[1], argv[2], argv[3]);
} else if (std::atoi(argv[4]) == 2) {
TrtInfer(argv[1], argv[2], argv[3]);
}
return 0;
}


@@ -0,0 +1,116 @@
# RobustVideoMatting: Dynamic ONNX Export for TRT
## Environment Dependencies
- python >= 3.5
- pytorch 1.12.0
- onnx 1.10.0
- onnxsim 0.4.8
## Step 1: Pull the onnx branch of RobustVideoMatting
```shell
git clone -b onnx https://github.com/PeterL1n/RobustVideoMatting.git
cd RobustVideoMatting
```
## Step 2: Remove the downsample_ratio dynamic input
In ```model/model.py```, remove the ```downsample_ratio``` input, as shown below
```python
def forward(self, src, r1, r2, r3, r4,
# downsample_ratio: float = 0.25,
segmentation_pass: bool = False):
if torch.onnx.is_in_onnx_export():
# src_sm = CustomOnnxResizeByFactorOp.apply(src, 0.25)
src_sm = self._interpolate(src, scale_factor=0.25)
elif downsample_ratio != 1:
src_sm = self._interpolate(src, scale_factor=0.25)
else:
src_sm = src
f1, f2, f3, f4 = self.backbone(src_sm)
f4 = self.aspp(f4)
hid, *rec = self.decoder(src_sm, f1, f2, f3, f4, r1, r2, r3, r4)
if not segmentation_pass:
fgr_residual, pha = self.project_mat(hid).split([3, 1], dim=-3)
# if torch.onnx.is_in_onnx_export() or downsample_ratio != 1:
if torch.onnx.is_in_onnx_export():
fgr_residual, pha = self.refiner(src, src_sm, fgr_residual, pha, hid)
fgr = fgr_residual + src
fgr = fgr.clamp(0., 1.)
pha = pha.clamp(0., 1.)
return [fgr, pha, *rec]
else:
seg = self.project_seg(hid)
return [seg, *rec]
```
## Step 3: Modify the ONNX export script
Modify the ```export_onnx.py``` script to remove the ```downsample_ratio``` input
```python
def export(self):
rec = (torch.zeros([1, 1, 1, 1]).to(self.args.device, self.precision),) * 4
# src = torch.randn(1, 3, 1080, 1920).to(self.args.device, self.precision)
src = torch.randn(1, 3, 1920, 1080).to(self.args.device, self.precision)
# downsample_ratio = torch.tensor([0.25]).to(self.args.device)
dynamic_spatial = {0: 'batch_size', 2: 'height', 3: 'width'}
dynamic_everything = {0: 'batch_size', 1: 'channels', 2: 'height', 3: 'width'}
torch.onnx.export(
self.model,
# (src, *rec, downsample_ratio),
(src, *rec),
self.args.output,
export_params=True,
opset_version=self.args.opset,
do_constant_folding=True,
# input_names=['src', 'r1i', 'r2i', 'r3i', 'r4i', 'downsample_ratio'],
input_names=['src', 'r1i', 'r2i', 'r3i', 'r4i'],
output_names=['fgr', 'pha', 'r1o', 'r2o', 'r3o', 'r4o'],
dynamic_axes={
'src': {0: 'batch_size0', 2: 'height0', 3: 'width0'},
'fgr': {0: 'batch_size1', 2: 'height1', 3: 'width1'},
'pha': {0: 'batch_size2', 2: 'height2', 3: 'width2'},
'r1i': {0: 'batch_size3', 1: 'channels3', 2: 'height3', 3: 'width3'},
'r2i': {0: 'batch_size4', 1: 'channels4', 2: 'height4', 3: 'width4'},
'r3i': {0: 'batch_size5', 1: 'channels5', 2: 'height5', 3: 'width5'},
'r4i': {0: 'batch_size6', 1: 'channels6', 2: 'height6', 3: 'width6'},
'r1o': {0: 'batch_size7', 2: 'height7', 3: 'width7'},
'r2o': {0: 'batch_size8', 2: 'height8', 3: 'width8'},
'r3o': {0: 'batch_size9', 2: 'height9', 3: 'width9'},
'r4o': {0: 'batch_size10', 2: 'height10', 3: 'width10'},
})
```
Run the following command
```shell
python export_onnx.py \
--model-variant mobilenetv3 \
--checkpoint rvm_mobilenetv3.pth \
--precision float32 \
--opset 12 \
--device cuda \
--output rvm_mobilenetv3.onnx
```
**Note**
- For TRT dynamic shapes on a multi-input ONNX model: if inputs x0 and x1 have different shapes, they must not both use the axis names height and width; distinguish them as height0, height1, and so on, otherwise the engine build stage will fail
## Step 4: Simplify with onnxsim
Install onnxsim and simplify the ONNX model exported in Step 3
```shell
pip install onnxsim
onnxsim rvm_mobilenetv3.onnx rvm_mobilenetv3_trt.onnx
```
```rvm_mobilenetv3_trt.onnx``` is the dynamic-shape ONNX model that can run on the TRT backend
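As an optional sanity check before building the TRT engine, the simplified model can be loaded and run once with ONNX Runtime. This is a minimal sketch, assuming `onnx`, `onnxruntime`, and `numpy` are installed; the 1920x1080 shape matches the export script above.
```python
import numpy as np
import onnx
import onnxruntime as ort

# Structural check of the simplified graph
onnx.checker.check_model(onnx.load("rvm_mobilenetv3_trt.onnx"))

# One dummy forward pass: src plus the four zero-initialized recurrent states
sess = ort.InferenceSession("rvm_mobilenetv3_trt.onnx")
feeds = {"src": np.random.rand(1, 3, 1920, 1080).astype(np.float32)}
for name in ["r1i", "r2i", "r3i", "r4i"]:
    feeds[name] = np.zeros((1, 1, 1, 1), dtype=np.float32)
fgr, pha, r1o, r2o, r3o, r4o = sess.run(None, feeds)
print(pha.shape)  # expected (1, 1, 1920, 1080)
```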


@@ -0,0 +1,88 @@
# RobustVideoMatting Python Deployment Example
Before deployment, confirm the following two steps
- 1. The hardware and software environment meets the requirements; refer to [FastDeploy Environment Requirements](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md)
- 2. Install the FastDeploy Python whl package; refer to [FastDeploy Python Installation](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md)
This directory provides `infer.py` to quickly deploy RobustVideoMatting on CPU/GPU, or on GPU with TensorRT acceleration. Run the following script to complete the deployment:
```bash
# Download the deployment example code
git clone https://github.com/PaddlePaddle/FastDeploy.git
cd FastDeploy/examples/vision/matting/rvm/python
# Download the RobustVideoMatting model files, test image, and video
## Original ONNX model
wget https://bj.bcebos.com/paddlehub/fastdeploy/rvm_mobilenetv3_fp32.onnx
## ONNX model specially processed for TensorRT loading
wget https://bj.bcebos.com/paddlehub/fastdeploy/rvm_mobilenetv3_trt.onnx
wget https://bj.bcebos.com/paddlehub/fastdeploy/matting_input.jpg
wget https://bj.bcebos.com/paddlehub/fastdeploy/matting_bgr.jpg
wget https://bj.bcebos.com/paddlehub/fastdeploy/video.mp4
# CPU inference
## Image
python infer.py --model rvm_mobilenetv3_fp32.onnx --image matting_input.jpg --bg matting_bgr.jpg --device cpu
## Video
python infer.py --model rvm_mobilenetv3_fp32.onnx --video video.mp4 --bg matting_bgr.jpg --device cpu
# GPU inference
## Image
python infer.py --model rvm_mobilenetv3_fp32.onnx --image matting_input.jpg --bg matting_bgr.jpg --device gpu
## Video
python infer.py --model rvm_mobilenetv3_fp32.onnx --video video.mp4 --bg matting_bgr.jpg --device gpu
# TensorRT inference
## Image
python infer.py --model rvm_mobilenetv3_trt.onnx --image matting_input.jpg --bg matting_bgr.jpg --device gpu --use_trt True
## Video
python infer.py --model rvm_mobilenetv3_trt.onnx --video video.mp4 --bg matting_bgr.jpg --device gpu --use_trt True
```
The visualized results are shown below:
<div width="1240">
<img width="200" height="200" float="left" src="https://user-images.githubusercontent.com/67993288/186852040-759da522-fca4-4786-9205-88c622cd4a39.jpg">
<img width="200" height="200" float="left" src="https://user-images.githubusercontent.com/67993288/186852587-48895efc-d24a-43c9-aeec-d7b0362ab2b9.jpg">
<img width="200" height="200" float="left" src="https://user-images.githubusercontent.com/67993288/186852116-cf91445b-3a67-45d9-a675-c69fe77c383a.jpg">
<img width="200" height="200" float="left" src="https://user-images.githubusercontent.com/67993288/186852554-6960659f-4fd7-4506-b33b-54e1a9dd89bf.jpg">
<img width="200" height="200" float="left" src="https://user-images.githubusercontent.com/19977378/196653716-f7043bd5-dfc2-4e7d-be0f-e12a6af4c55b.gif">
<img width="200" height="200" float="left" src="https://user-images.githubusercontent.com/19977378/196654529-866bff5d-47a2-4584-9627-39b587799228.gif">
</div>
## RobustVideoMatting Python Interface
```python
fd.vision.matting.RobustVideoMatting(model_file, params_file=None, runtime_option=None, model_format=ModelFormat.ONNX)
```
RobustVideoMatting model loading and initialization, where model_file is the exported ONNX model
**Parameters**
> * **model_file**(str): path to the model file
> * **params_file**(str): path to the parameter file; when the model format is ONNX, this parameter does not need to be set
> * **runtime_option**(RuntimeOption): backend inference configuration; None uses the default configuration
> * **model_format**(ModelFormat): model format, ONNX by default
### predict function
> ```python
> RobustVideoMatting.predict(input_image)
> ```
>
> Model prediction interface; takes an image and returns the matting result.
>
> **Parameters**
>
> > * **input_image**(np.ndarray): input data, which must be in HWC, BGR format
> **Return**
>
> > Returns a `fastdeploy.vision.MattingResult` structure; see [Vision Model Prediction Results](../../../../../docs/api/vision_results/) for details
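A minimal end-to-end sketch of the interface above, using the same calls as `infer.py` in this directory (the file names are the ones downloaded in the steps above):
```python
import cv2
import fastdeploy as fd

model = fd.vision.matting.RobustVideoMatting("rvm_mobilenetv3_fp32.onnx")

im = cv2.imread("matting_input.jpg")
bg = cv2.imread("matting_bgr.jpg")
result = model.predict(im)  # fastdeploy.vision.MattingResult

# Visualize the foreground/alpha and compose onto the new background
cv2.imwrite("visualized_result_fg.jpg", fd.vision.vis_matting(im, result))
cv2.imwrite("visualized_result_replaced_bg.jpg",
            fd.vision.swap_background_matting(im, bg, result))
```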
## Other Documents
- [RobustVideoMatting model introduction](..)
- [RobustVideoMatting C++ deployment](../cpp)
- [Prediction result description](../../../../../docs/api/vision_results/)
- [How to switch the model inference backend](../../../../../docs/cn/faq/how_to_change_backend.md)


@@ -0,0 +1,112 @@
import fastdeploy as fd
import cv2
import os
def parse_arguments():
import argparse
import ast
parser = argparse.ArgumentParser()
parser.add_argument(
"--model", required=True, help="Path of RobustVideoMatting model.")
parser.add_argument("--image", type=str, help="Path of test image file.")
parser.add_argument("--video", type=str, help="Path of test video file.")
parser.add_argument(
"--bg",
type=str,
required=True,
default=None,
help="Path of test background image file.")
parser.add_argument(
'--output-composition',
type=str,
default="composition.mp4",
help="Path of composition video file.")
parser.add_argument(
'--output-alpha',
type=str,
default="alpha.mp4",
help="Path of alpha video file.")
parser.add_argument(
"--device",
type=str,
default='cpu',
help="Type of inference device, support 'cpu' or 'gpu'.")
parser.add_argument(
"--use_trt",
type=ast.literal_eval,
default=False,
help="Wether to use tensorrt.")
return parser.parse_args()
def build_option(args):
option = fd.RuntimeOption()
if args.device.lower() == "gpu":
option.use_gpu()
if args.use_trt:
option.use_trt_backend()
option.set_trt_input_shape("src", [1, 3, 1920, 1080])
option.set_trt_input_shape("r1i", [1, 1, 1, 1], [1, 16, 240, 135],
[1, 16, 240, 135])
option.set_trt_input_shape("r2i", [1, 1, 1, 1], [1, 20, 120, 68],
[1, 20, 120, 68])
option.set_trt_input_shape("r3i", [1, 1, 1, 1], [1, 40, 60, 34],
[1, 40, 60, 34])
option.set_trt_input_shape("r4i", [1, 1, 1, 1], [1, 64, 30, 17],
[1, 64, 30, 17])
return option
args = parse_arguments()
output_composition = args.output_composition
output_alpha = args.output_alpha
# Configure the runtime and load the model
runtime_option = build_option(args)
model = fd.vision.matting.RobustVideoMatting(
args.model, runtime_option=runtime_option)
bg = cv2.imread(args.bg)
if args.video is not None:
# for video
cap = cv2.VideoCapture(args.video)
# Define the codec and create VideoWriter object
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
composition = cv2.VideoWriter(output_composition, fourcc, 20.0,
(1080, 1920))
alpha = cv2.VideoWriter(output_alpha, fourcc, 20.0, (1080, 1920))
frame_id = 0
while True:
frame_id = frame_id + 1
_, frame = cap.read()
if frame is None:
break
result = model.predict(frame)
vis_im = fd.vision.vis_matting(frame, result)
vis_im_with_bg = fd.vision.swap_background_matting(frame, bg, result)
alpha.write(vis_im)
composition.write(vis_im_with_bg)
cv2.waitKey(30)
cap.release()
composition.release()
alpha.release()
cv2.destroyAllWindows()
print("Visualized result video save in {} and {}".format(
output_composition, output_alpha))
if args.image is not None:
# for image
im = cv2.imread(args.image)
result = model.predict(im.copy())
print(result)
# Visualize the results
vis_im = fd.vision.vis_matting(im, result)
vis_im_with_bg = fd.vision.swap_background_matting(im, bg, result)
cv2.imwrite("visualized_result_fg.jpg", vis_im)
cv2.imwrite("visualized_result_replaced_bg.jpg", vis_im_with_bg)
print(
"Visualized result save in ./visualized_result_replaced_bg.jpg and ./visualized_result_fg.jpg"
)


@@ -27,11 +27,6 @@ namespace baidu {
namespace mirana {
namespace poros {
/**
* the base engine class
* every registered engine should inherit from this IEngine
**/
struct PorosGraph {
torch::jit::Graph* graph = NULL;
torch::jit::Node* node = NULL;

fastdeploy/backends/tensorrt/trt_backend.cc Normal file → Executable file

@@ -525,7 +525,7 @@ bool TrtBackend::BuildTrtEngine() {
engine_->createExecutionContext());
GetInputOutputInfo();
FDINFO << "TensorRT Engine is built succussfully." << std::endl;
FDINFO << "TensorRT Engine is built successfully." << std::endl;
if (option_.serialize_file != "") {
FDINFO << "Serialize TensorRTEngine to local file "
<< option_.serialize_file << "." << std::endl;


@@ -41,6 +41,7 @@
#include "fastdeploy/vision/faceid/contrib/vpl.h"
#include "fastdeploy/vision/keypointdet/pptinypose/pptinypose.h"
#include "fastdeploy/vision/matting/contrib/modnet.h"
#include "fastdeploy/vision/matting/contrib/rvm.h"
#include "fastdeploy/vision/matting/ppmatting/ppmatting.h"
#include "fastdeploy/vision/ocr/ppocr/classifier.h"
#include "fastdeploy/vision/ocr/ppocr/dbdetector.h"

fastdeploy/vision/matting/contrib/modnet.cc Normal file → Executable file

@@ -86,7 +86,7 @@ bool MODNet::Postprocess(
FDASSERT((infer_result.size() == 1),
"The default number of output tensor must be 1 according to "
"modnet.");
FDTensor& alpha_tensor = infer_result.at(0); // (1,h,w,1)
FDTensor& alpha_tensor = infer_result.at(0); // (1, 1, h, w)
FDASSERT((alpha_tensor.shape[0] == 1), "Only support batch =1 now.");
if (alpha_tensor.dtype != FDDataType::FP32) {
FDERROR << "Only support post process with float32 data." << std::endl;


@@ -0,0 +1,182 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/matting/contrib/rvm.h"
#include "fastdeploy/utils/perf.h"
#include "fastdeploy/vision/utils/utils.h"
namespace fastdeploy {
namespace vision {
namespace matting {
RobustVideoMatting::RobustVideoMatting(const std::string& model_file, const std::string& params_file,
const RuntimeOption& custom_option,
const ModelFormat& model_format) {
if (model_format == ModelFormat::ONNX) {
valid_cpu_backends = {Backend::OPENVINO, Backend::ORT};
valid_gpu_backends = {Backend::ORT, Backend::TRT};
} else {
valid_cpu_backends = {Backend::PDINFER, Backend::ORT};
valid_gpu_backends = {Backend::PDINFER, Backend::ORT, Backend::TRT};
}
runtime_option = custom_option;
runtime_option.model_format = model_format;
runtime_option.model_file = model_file;
runtime_option.params_file = params_file;
initialized = Initialize();
}
bool RobustVideoMatting::Initialize() {
// parameters for preprocess
size = {1080, 1920};
video_mode = true;
if (!InitRuntime()) {
FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
return false;
}
return true;
}
bool RobustVideoMatting::Preprocess(Mat* mat, FDTensor* output,
std::map<std::string, std::array<int, 2>>* im_info) {
// Resize
int resize_w = size[0];
int resize_h = size[1];
if (resize_h != mat->Height() || resize_w != mat->Width()) {
Resize::Run(mat, resize_w, resize_h);
}
BGR2RGB::Run(mat);
// Normalize
std::vector<float> alpha = {1.0f / 255.0f, 1.0f / 255.0f, 1.0f / 255.0f};
std::vector<float> beta = {0.0f, 0.0f, 0.0f};
Convert::Run(mat, alpha, beta);
// Record output shape of preprocessed image
(*im_info)["output_shape"] = {mat->Height(), mat->Width()};
HWC2CHW::Run(mat);
Cast::Run(mat, "float");
mat->ShareWithTensor(output);
output->shape.insert(output->shape.begin(), 1); // expand to n, c, h, w
return true;
}
bool RobustVideoMatting::Postprocess(
std::vector<FDTensor>& infer_result, MattingResult* result,
const std::map<std::string, std::array<int, 2>>& im_info) {
FDASSERT((infer_result.size() == 6),
"The default number of output tensor must be 6 according to "
"RobustVideoMatting.");
FDTensor& fgr = infer_result.at(0); // fgr (1, 3, h, w) 0.~1.
FDTensor& alpha = infer_result.at(1); // alpha (1, 1, h, w) 0.~1.
FDASSERT((fgr.shape[0] == 1), "Only support batch = 1 now.");
FDASSERT((alpha.shape[0] == 1), "Only support batch = 1 now.");
if (fgr.dtype != FDDataType::FP32) {
FDERROR << "Only support post process with float32 data." << std::endl;
return false;
}
if (alpha.dtype != FDDataType::FP32) {
FDERROR << "Only support post process with float32 data." << std::endl;
return false;
}
// update context
if (video_mode) {
for (size_t i = 0; i < 4; ++i) {
FDTensor& rki = infer_result.at(i+2);
dynamic_inputs_dims_[i] = rki.shape;
dynamic_inputs_datas_[i].resize(rki.Numel());
memcpy(dynamic_inputs_datas_[i].data(), rki.Data(),
rki.Numel() * FDDataTypeSize(rki.dtype));
}
}
auto iter_in = im_info.find("input_shape");
auto iter_out = im_info.find("output_shape");
FDASSERT(iter_out != im_info.end() && iter_in != im_info.end(),
"Cannot find input_shape or output_shape from im_info.");
int out_h = iter_out->second[0];
int out_w = iter_out->second[1];
int in_h = iter_in->second[0];
int in_w = iter_in->second[1];
// for alpha
float* alpha_ptr = static_cast<float*>(alpha.Data());
cv::Mat alpha_zero_copy_ref(out_h, out_w, CV_32FC1, alpha_ptr);
Mat alpha_resized(alpha_zero_copy_ref); // ref-only, zero copy.
if ((out_h != in_h) || (out_w != in_w)) {
// already allocated a new continuous memory after resize.
Resize::Run(&alpha_resized, in_w, in_h, -1, -1);
}
// for foreground
float* fgr_ptr = static_cast<float*>(fgr.Data());
cv::Mat fgr_zero_copy_ref(out_h, out_w, CV_32FC1, fgr_ptr);
Mat fgr_resized(fgr_zero_copy_ref); // ref-only, zero copy.
if ((out_h != in_h) || (out_w != in_w)) {
// already allocated a new continuous memory after resize.
Resize::Run(&fgr_resized, in_w, in_h, -1, -1);
}
result->Clear();
result->contain_foreground = true;
result->shape = {static_cast<int64_t>(in_h), static_cast<int64_t>(in_w)};
int numel = in_h * in_w;
int nbytes = numel * sizeof(float);
result->Resize(numel);
memcpy(result->alpha.data(), alpha_resized.GetOpenCVMat()->data, nbytes);
memcpy(result->foreground.data(), fgr_resized.GetOpenCVMat()->data, nbytes);
return true;
}
bool RobustVideoMatting::Predict(cv::Mat* im, MattingResult* result) {
Mat mat(*im);
int inputs_nums = NumInputsOfRuntime();
std::vector<FDTensor> input_tensors(inputs_nums);
std::map<std::string, std::array<int, 2>> im_info;
// Record the shape of image and the shape of preprocessed image
im_info["input_shape"] = {mat.Height(), mat.Width()};
im_info["output_shape"] = {mat.Height(), mat.Width()};
// fill the recurrent-state inputs (r1i~r4i) from the cached dynamic data
for (size_t i = 1; i < inputs_nums; ++i) {
input_tensors[i].SetExternalData(dynamic_inputs_dims_[i-1], FDDataType::FP32, dynamic_inputs_datas_[i-1].data());
input_tensors[i].device = Device::CPU;
}
if (!Preprocess(&mat, &input_tensors[0], &im_info)) {
FDERROR << "Failed to preprocess input image." << std::endl;
return false;
}
for (size_t i = 0; i < inputs_nums; ++i) {
input_tensors[i].name = InputInfoOfRuntime(i).name;
}
std::vector<FDTensor> output_tensors;
if (!Infer(input_tensors, &output_tensors)) {
FDERROR << "Failed to inference." << std::endl;
return false;
}
if (!Postprocess(output_tensors, result, im_info)) {
FDERROR << "Failed to post process." << std::endl;
return false;
}
return true;
}
} // namespace matting
} // namespace vision
} // namespace fastdeploy


@@ -0,0 +1,92 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/fastdeploy_model.h"
#include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/result.h"
namespace fastdeploy {
namespace vision {
/** \brief All image/video matting model APIs are defined inside this namespace
*
*/
namespace matting {
/*! @brief RobustVideoMatting model object, used to load a RobustVideoMatting model exported from PeterL1n/RobustVideoMatting
*/
class FASTDEPLOY_DECL RobustVideoMatting : public FastDeployModel {
public:
/** \brief Set path of model file and configuration file, and the configuration of runtime
*
* \param[in] model_file Path of model file, e.g rvm/rvm_mobilenetv3_fp32.onnx
* \param[in] params_file Path of parameter file, if the model format is ONNX, this parameter will be ignored
* \param[in] custom_option RuntimeOption for inference, the default will use cpu, and choose the backend defined in `valid_cpu_backends`
* \param[in] model_format Model format of the loaded model, default is ONNX format
*/
RobustVideoMatting(const std::string& model_file,
const std::string& params_file = "",
const RuntimeOption& custom_option = RuntimeOption(),
const ModelFormat& model_format = ModelFormat::ONNX);
/// Get model's name
std::string ModelName() const { return "matting/RobustVideoMatting"; }
/** \brief Predict the matting result for an input image
*
* \param[in] im The input image data, comes from cv::imread()
* \param[in] result The output matting result will be written to this structure
* \return true if the prediction succeeded, otherwise false
*/
bool Predict(cv::Mat* im, MattingResult* result);
/// Preprocess image size, the default is (1080, 1920)
std::vector<int> size;
/// Whether to enable video mode; if the inputs are unrelated still images, set it to false. The default is true // NOLINT
bool video_mode;
private:
bool Initialize();
/// Preprocess an input image, and set the preprocessed result to `output`
bool Preprocess(Mat* mat, FDTensor* output,
std::map<std::string, std::array<int, 2>>* im_info);
/// Postprocess the inferenced results, and set the final result to `result`
bool Postprocess(std::vector<FDTensor>& infer_result, MattingResult* result,
const std::map<std::string, std::array<int, 2>>& im_info);
/// Initial data for the dynamic inputs
std::vector<std::vector<float>> dynamic_inputs_datas_ = {
{0.0f}, // r1i
{0.0f}, // r2i
{0.0f}, // r3i
{0.0f}, // r4i
{0.25f}, // downsample_ratio
};
/// Initial dims for the dynamic inputs
std::vector<std::vector<int64_t>> dynamic_inputs_dims_ = {
{1, 1, 1, 1}, // r1i
{1, 1, 1, 1}, // r2i
{1, 1, 1, 1}, // r3i
{1, 1, 1, 1}, // r4i
{1}, // downsample_ratio
};
};
} // namespace matting
} // namespace vision
} // namespace fastdeploy


@@ -0,0 +1,34 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/pybind/main.h"
namespace fastdeploy {
void BindRobustVideoMatting(pybind11::module& m) {
// Bind RobustVideoMatting
pybind11::class_<vision::matting::RobustVideoMatting, FastDeployModel>(m, "RobustVideoMatting")
.def(pybind11::init<std::string, std::string, RuntimeOption,
ModelFormat>())
.def("predict",
[](vision::matting::RobustVideoMatting& self, pybind11::array& data) {
auto mat = PyArrayToCvMat(data);
vision::MattingResult res;
self.Predict(&mat, &res);
return res;
})
.def_readwrite("size", &vision::matting::RobustVideoMatting::size)
.def_readwrite("video_mode", &vision::matting::RobustVideoMatting::video_mode);
}
} // namespace fastdeploy


@@ -17,12 +17,14 @@
namespace fastdeploy {
void BindMODNet(pybind11::module& m);
void BindRobustVideoMatting(pybind11::module& m);
void BindPPMatting(pybind11::module& m);
void BindMatting(pybind11::module& m) {
auto matting_module =
m.def_submodule("matting", "Image object matting models.");
m.def_submodule("matting", "Image/Video matting models.");
BindMODNet(matting_module);
BindRobustVideoMatting(matting_module);
BindPPMatting(matting_module);
}
} // namespace fastdeploy

fastdeploy/vision/matting/ppmatting/ppmatting.cc Normal file → Executable file

@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/matting/ppmatting/ppmatting.h"
#include "fastdeploy/vision.h"
#include "fastdeploy/vision/utils/utils.h"
#include "yaml-cpp/yaml.h"
@@ -163,7 +162,7 @@ bool PPMatting::Postprocess(
const std::map<std::string, std::array<int, 2>>& im_info) {
FDASSERT((infer_result.size() == 1),
"The default number of output tensor must be 1 ");
FDTensor& alpha_tensor = infer_result.at(0); // (1,h,w,1)
FDTensor& alpha_tensor = infer_result.at(0); // (1, 1, h, w)
FDASSERT((alpha_tensor.shape[0] == 1), "Only support batch = 1 now.");
if (alpha_tensor.dtype != FDDataType::FP32) {
FDERROR << "Only support post process with float32 data." << std::endl;

fastdeploy/vision/segmentation/ppseg/model.cc Normal file → Executable file

@@ -13,7 +13,6 @@
// limitations under the License.
#include "fastdeploy/vision/segmentation/ppseg/model.h"
#include "fastdeploy/vision.h"
#include "fastdeploy/vision/utils/utils.h"
#include "yaml-cpp/yaml.h"


@@ -14,4 +14,5 @@
from __future__ import absolute_import
from .contrib.modnet import MODNet
from .contrib.rvm import RobustVideoMatting
from .ppmatting import PPMatting


@@ -0,0 +1,81 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import logging
from .... import FastDeployModel, ModelFormat
from .... import c_lib_wrap as C
class RobustVideoMatting(FastDeployModel):
def __init__(self,
model_file,
params_file="",
runtime_option=None,
model_format=ModelFormat.ONNX):
"""Load a video matting model exported by RobustVideoMatting.
:param model_file: (str)Path of model file, e.g rvm/rvm_mobilenetv3_fp32.onnx
:param params_file: (str)Path of parameters file, if the model_format is ModelFormat.ONNX, this param will be ignored, can be set as empty string
:param runtime_option: (fastdeploy.RuntimeOption)RuntimeOption for inference this model, if it's None, will use the default backend on CPU
:param model_format: (fastdeploy.ModelFormat)Model format of the loaded model, default is ONNX
"""
super(RobustVideoMatting, self).__init__(runtime_option)
self._model = C.vision.matting.RobustVideoMatting(
model_file, params_file, self._runtime_option, model_format)
assert self.initialized, "RobustVideoMatting initialize failed."
def predict(self, input_image):
"""Matting an input image
:param im: (numpy.ndarray)The input image data, 3-D array with layout HWC, BGR format
:return: MattingResult
"""
return self._model.predict(input_image)
@property
def size(self):
"""
Returns the preprocess image size
"""
return self._model.size
@property
def video_mode(self):
"""
Whether to enable video mode; if the inputs are unrelated still images, set it to false. The default is true
"""
return self._model.video_mode
@size.setter
def size(self, wh):
"""
Set the preprocess image size
"""
assert isinstance(wh, (list, tuple)),\
"The value to set `size` must be type of tuple or list."
assert len(wh) == 2,\
"The value to set `size` must contatins 2 elements means [width, height], but now it contains {} elements.".format(
len(wh))
self._model.size = wh
@video_mode.setter
def video_mode(self, value):
"""
Set video_mode property, the default is true
"""
assert isinstance(
value, bool), "The value to set `video_mode` must be type of bool."
self._model.video_mode = value


@@ -0,0 +1,101 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import fastdeploy as fd
import cv2
import os
import pickle
import numpy as np
def test_matting_rvm_cpu():
model_url = "https://bj.bcebos.com/paddlehub/fastdeploy/rvm.tgz"
input_url = "https://bj.bcebos.com/paddlehub/fastdeploy/video.mp4"
fd.download_and_decompress(model_url, ".")
fd.download(input_url, ".")
model_path = "rvm/rvm_mobilenetv3_fp32.onnx"
# use ORT
runtime_option = fd.RuntimeOption()
runtime_option.use_ort_backend()
model = fd.vision.matting.RobustVideoMatting(
model_path, runtime_option=runtime_option)
cap = cv2.VideoCapture(input_url)
frame_id = 0
while True:
_, frame = cap.read()
if frame is None:
break
result = model.predict(frame)
# compare diff
expect_alpha = np.load("rvm/result_alpha_" + str(frame_id) + ".npy")
result_alpha = np.array(result.alpha).reshape(1920, 1080)
diff = np.fabs(expect_alpha - result_alpha)
thres = 1e-05
assert diff.max(
) < thres, "The label diff is %f, which is bigger than %f" % (
diff.max(), thres)
frame_id = frame_id + 1
cv2.waitKey(30)
if frame_id >= 10:
cap.release()
cv2.destroyAllWindows()
break
def test_matting_rvm_gpu_trt():
model_url = "https://bj.bcebos.com/paddlehub/fastdeploy/rvm.tgz"
input_url = "https://bj.bcebos.com/paddlehub/fastdeploy/video.mp4"
fd.download_and_decompress(model_url, ".")
fd.download(input_url, ".")
model_path = "rvm/rvm_mobilenetv3_trt.onnx"
# use TRT
runtime_option = fd.RuntimeOption()
runtime_option.use_gpu()
runtime_option.use_trt_backend()
runtime_option.set_trt_input_shape("src", [1, 3, 1920, 1080])
runtime_option.set_trt_input_shape("r1i", [1, 1, 1, 1], [1, 16, 240, 135],
[1, 16, 240, 135])
runtime_option.set_trt_input_shape("r2i", [1, 1, 1, 1], [1, 20, 120, 68],
[1, 20, 120, 68])
runtime_option.set_trt_input_shape("r3i", [1, 1, 1, 1], [1, 40, 60, 34],
[1, 40, 60, 34])
runtime_option.set_trt_input_shape("r4i", [1, 1, 1, 1], [1, 64, 30, 17],
[1, 64, 30, 17])
model = fd.vision.matting.RobustVideoMatting(
model_path, runtime_option=runtime_option)
cap = cv2.VideoCapture("./video.mp4")
frame_id = 0
while True:
_, frame = cap.read()
if frame is None:
break
result = model.predict(frame)
# compare diff
expect_alpha = np.load("rvm/result_alpha_" + str(frame_id) + ".npy")
result_alpha = np.array(result.alpha).reshape(1920, 1080)
diff = np.fabs(expect_alpha - result_alpha)
thres = 1e-04
assert diff.max(
) < thres, "The label diff is %f, which is bigger than %f" % (
diff.max(), thres)
frame_id = frame_id + 1
cv2.waitKey(30)
if frame_id >= 10:
cap.release()
cv2.destroyAllWindows()
break