Files
FastDeploy/examples/vision/sr/basicvsr/python/infer.py
ChaoII c7ec14de95 [Model] add vsr serials models (#518)
* [Model] add vsr serials models

Signed-off-by: ChaoII <849453582@qq.com>

* [Model] add vsr serials models

Signed-off-by: ChaoII <849453582@qq.com>

* fix build problem

Signed-off-by: ChaoII <849453582@qq.com>

* fix code style

Signed-off-by: ChaoII <849453582@qq.com>

* modify according to review suggestions

Signed-off-by: ChaoII <849453582@qq.com>

* modify vsr trt example

Signed-off-by: ChaoII <849453582@qq.com>

* update sr directory

* fix BindPPSR

* add doxygen comment

* add sr unit test

* update model file url

Signed-off-by: ChaoII <849453582@qq.com>
Co-authored-by: Jason <jiangjiajun@baidu.com>
2022-11-21 10:58:28 +08:00

87 lines
2.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import cv2
import os
import fastdeploy as fd
def parse_arguments():
    """Parse the command-line arguments of the BasicVSR inference demo.

    Returns:
        argparse.Namespace with attributes ``model`` (model directory),
        ``video`` (input video path), ``frame_num`` (frames per inference
        call, default 2), ``device`` (default ``'cpu'``) and ``use_trt``
        (bool parsed via ``ast.literal_eval``, default ``False``).
    """
    import argparse
    import ast
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True, help="Path of model.")
    parser.add_argument(
        "--video", type=str, required=True, help="Path of test video file.")
    parser.add_argument(
        "--frame_num",
        type=int,
        default=2,
        help="Number of frames fed to the model per inference call; must "
        "match the second dimension (n) of the exported model input shape "
        "[b, n, c, h, w].")
    parser.add_argument(
        "--device",
        type=str,
        default='cpu',
        help="Type of inference device, support 'cpu' or 'gpu'.")
    parser.add_argument(
        "--use_trt",
        # literal_eval (not bool()) so that the string "False" parses to
        # False; it only evaluates Python literals, so it is safe on CLI input.
        type=ast.literal_eval,
        default=False,
        help="Whether to use tensorrt.")
    return parser.parse_args()
def build_option(args):
    """Build an ``fd.RuntimeOption`` from the parsed CLI arguments.

    ``args.device`` equal to ``"gpu"`` (compared case-insensitively) triggers
    ``use_gpu()``. A truthy ``args.use_trt`` additionally calls
    ``use_trt_backend()`` followed by ``enable_paddle_to_trt()``.
    """
    runtime_opt = fd.RuntimeOption()
    wants_gpu = args.device.lower() == "gpu"
    if wants_gpu:
        runtime_opt.use_gpu()
    if not args.use_trt:
        return runtime_opt
    # NOTE(review): this branch runs even with --device cpu; TensorRT
    # presumably requires a GPU -- confirm the intended behaviour.
    runtime_opt.use_trt_backend()
    runtime_opt.enable_paddle_to_trt()
    return runtime_opt
args = parse_arguments()
# 配置runtime加载模型
runtime_option = build_option(args)
model_file = os.path.join(args.model, "model.pdmodel")
params_file = os.path.join(args.model, "model.pdiparams")
model = fd.vision.sr.BasicVSR(
model_file, params_file, runtime_option=runtime_option)
# 该处应该与你导出模型的第二个维度一致模型输入shape=[b,n,c,h,w]
capture = cv2.VideoCapture(args.video)
video_out_name = "output.mp4"
video_fps = capture.get(cv2.CAP_PROP_FPS)
video_frame_count = capture.get(cv2.CAP_PROP_FRAME_COUNT)
# 注意导出模型时尺寸与原始输入的分辨一致比如:[1,2,3,180,320],经过4x超分后[1,2,3,720,1280]
# 所以导出模型相当重要(最关键的是根据netron查看网络输出shape)
out_width = 1280
out_height = 720
print(f"fps: {video_fps}\tframe_count: {video_frame_count}")
# Create VideoWriter for output
video_out_dir = "./"
video_out_path = os.path.join(video_out_dir, video_out_name)
fucc = cv2.VideoWriter_fourcc(*"mp4v")
video_out = cv2.VideoWriter(video_out_path, fucc, video_fps,
(out_width, out_height), True)
if not video_out.isOpened():
print("create video writer failed!")
# Capture all frames and do inference
frame_id = 0
reach_end = False
while capture.isOpened():
imgs = []
for i in range(args.frame_num):
_, frame = capture.read()
if frame is not None:
imgs.append(frame)
else:
reach_end = True
if reach_end:
break
results = model.predict(imgs)
for item in results:
# cv2.imshow("13", item)
# cv2.waitKey(30)
video_out.write(item)
print("Processing frame: ", frame_id)
frame_id += 1
print("inference finished, output video saved at: ", video_out_path)
capture.release()
video_out.release()