mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 00:57:33 +08:00

* [Model] add vsr serials models Signed-off-by: ChaoII <849453582@qq.com> * [Model] add vsr serials models Signed-off-by: ChaoII <849453582@qq.com> * fix build problem Signed-off-by: ChaoII <849453582@qq.com> * fix code style Signed-off-by: ChaoII <849453582@qq.com> * modify according to review suggestions Signed-off-by: ChaoII <849453582@qq.com> * modify vsr trt example Signed-off-by: ChaoII <849453582@qq.com> * update sr directory * fix BindPPSR * add doxygen comment * add sr unit test * update model file url Signed-off-by: ChaoII <849453582@qq.com> Co-authored-by: Jason <jiangjiajun@baidu.com>
87 lines
2.7 KiB
Python
87 lines
2.7 KiB
Python
import cv2
|
||
import os
|
||
import fastdeploy as fd
|
||
|
||
|
||
def parse_arguments(argv=None):
    """Parse the command-line options of this video super-resolution demo.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, so calling
            ``parse_arguments()`` behaves exactly like a plain script entry.

    Returns:
        argparse.Namespace with the attributes ``model``, ``video``,
        ``frame_num``, ``device`` and ``use_trt``.
    """
    import argparse
    import ast

    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True, help="Path of model.")
    parser.add_argument(
        "--video", type=str, required=True, help="Path of test video file.")
    parser.add_argument("--frame_num", type=int, default=2, help="frame num")
    parser.add_argument(
        "--device",
        type=str,
        default='cpu',
        help="Type of inference device, support 'cpu' or 'gpu'.")
    parser.add_argument(
        "--use_trt",
        # literal_eval converts the textual "True"/"False" into a real bool
        # instead of leaving a (always truthy) non-empty string.
        type=ast.literal_eval,
        default=False,
        help="Whether to use tensorrt.")
    return parser.parse_args(argv)
|
||
|
||
|
||
def build_option(args):
    """Create the FastDeploy runtime option selected by the parsed CLI flags.

    Args:
        args: Parsed arguments; only ``device`` and ``use_trt`` are read.

    Returns:
        The configured ``fd.RuntimeOption`` instance.
    """
    runtime_opt = fd.RuntimeOption()

    # The device name is compared case-insensitively.
    wants_gpu = args.device.lower() == "gpu"
    if wants_gpu:
        runtime_opt.use_gpu()

    # NOTE(review): the TensorRT setup is treated as independent of the device
    # check above (not nested inside it) -- confirm against the upstream file.
    if not args.use_trt:
        return runtime_opt

    runtime_opt.use_trt_backend()
    runtime_opt.enable_paddle_to_trt()
    return runtime_opt
|
||
|
||
|
||
def main():
    """Run BasicVSR super-resolution on a video and save it as ./output.mp4.

    Frames are read in batches of ``--frame_num``, passed to the model, and
    every returned frame is appended to the output video. A trailing batch
    with fewer than ``--frame_num`` frames is dropped, not inferred.

    Raises:
        ValueError: If ``--frame_num`` is smaller than 1.
        RuntimeError: If the input video or the output writer cannot be opened.
    """
    args = parse_arguments()
    if args.frame_num < 1:
        # With an empty batch no read ever fails, so the loop would never end.
        raise ValueError("--frame_num must be at least 1")

    # Configure the runtime and load the model.
    runtime_option = build_option(args)
    model_file = os.path.join(args.model, "model.pdmodel")
    params_file = os.path.join(args.model, "model.pdiparams")
    model = fd.vision.sr.BasicVSR(
        model_file, params_file, runtime_option=runtime_option)

    # The number of frames per batch should match the second dimension of the
    # exported model; the model input shape is [b, n, c, h, w].
    capture = cv2.VideoCapture(args.video)
    video_out = None
    video_out_name = "output.mp4"
    video_out_dir = "./"
    video_out_path = os.path.join(video_out_dir, video_out_name)
    try:
        if not capture.isOpened():
            raise RuntimeError(f"failed to open video: {args.video}")

        video_fps = capture.get(cv2.CAP_PROP_FPS)
        video_frame_count = capture.get(cv2.CAP_PROP_FRAME_COUNT)
        # Note: the size used when exporting the model must match the
        # resolution of the original input, e.g. [1,2,3,180,320] becomes
        # [1,2,3,720,1280] after 4x super-resolution. How the model is
        # exported therefore matters a lot (most importantly, check the
        # network output shape with netron).
        out_width = 1280
        out_height = 720
        print(f"fps: {video_fps}\tframe_count: {video_frame_count}")

        # Create VideoWriter for output
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        video_out = cv2.VideoWriter(video_out_path, fourcc, video_fps,
                                    (out_width, out_height), True)
        if not video_out.isOpened():
            # Continuing would run the whole inference and write nothing.
            raise RuntimeError("create video writer failed!")

        # Capture all frames and do inference
        frame_id = 0
        while capture.isOpened():
            imgs = []
            for _ in range(args.frame_num):
                ok, frame = capture.read()
                if not ok or frame is None:
                    break
                imgs.append(frame)
            if len(imgs) < args.frame_num:
                # End of stream reached; an incomplete batch is not inferred.
                break
            results = model.predict(imgs)
            for item in results:
                video_out.write(item)
                print("Processing frame: ", frame_id)
                frame_id += 1
    finally:
        # Release both handles even when inference fails half-way.
        capture.release()
        if video_out is not None:
            video_out.release()

    print("inference finished, output video saved at: ", video_out_path)


if __name__ == "__main__":
    main()
|