mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 00:57:33 +08:00
[Model] add vsr serials models (#518)
* [Model] add vsr serials models Signed-off-by: ChaoII <849453582@qq.com> * [Model] add vsr serials models Signed-off-by: ChaoII <849453582@qq.com> * fix build problem Signed-off-by: ChaoII <849453582@qq.com> * fix code style Signed-off-by: ChaoII <849453582@qq.com> * modify according to review suggestions Signed-off-by: ChaoII <849453582@qq.com> * modify vsr trt example Signed-off-by: ChaoII <849453582@qq.com> * update sr directory * fix BindPPSR * add doxygen comment * add sr unit test * update model file url Signed-off-by: ChaoII <849453582@qq.com> Co-authored-by: Jason <jiangjiajun@baidu.com>
This commit is contained in:
86
examples/vision/sr/ppmsvsr/python/infer.py
Normal file
86
examples/vision/sr/ppmsvsr/python/infer.py
Normal file
@@ -0,0 +1,86 @@
|
||||
import cv2
|
||||
import os
|
||||
import fastdeploy as fd
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
import argparse
|
||||
import ast
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model", required=True, help="Path of model.")
|
||||
parser.add_argument(
|
||||
"--video", type=str, required=True, help="Path of test video file.")
|
||||
parser.add_argument("--frame_num", type=int, default=2, help="frame num")
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default='cpu',
|
||||
help="Type of inference device, support 'cpu' or 'gpu'.")
|
||||
parser.add_argument(
|
||||
"--use_trt",
|
||||
type=ast.literal_eval,
|
||||
default=False,
|
||||
help="Wether to use tensorrt.")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def build_option(args):
|
||||
option = fd.RuntimeOption()
|
||||
if args.device.lower() == "gpu":
|
||||
option.use_gpu()
|
||||
if args.use_trt:
|
||||
option.use_trt_backend()
|
||||
option.enable_paddle_to_trt()
|
||||
return option
|
||||
|
||||
|
||||
args = parse_arguments()
|
||||
|
||||
# 配置runtime,加载模型
|
||||
runtime_option = build_option(args)
|
||||
model_file = os.path.join(args.model, "model.pdmodel")
|
||||
params_file = os.path.join(args.model, "model.pdiparams")
|
||||
model = fd.vision.sr.PPMSVSR(
|
||||
model_file, params_file, runtime_option=runtime_option)
|
||||
|
||||
# 该处应该与你导出模型的第二个维度一致模型输入shape=[b,n,c,h,w]
|
||||
capture = cv2.VideoCapture(args.video)
|
||||
video_out_name = "output.mp4"
|
||||
video_fps = capture.get(cv2.CAP_PROP_FPS)
|
||||
video_frame_count = capture.get(cv2.CAP_PROP_FRAME_COUNT)
|
||||
# 注意导出模型时尺寸与原始输入的分辨一致比如:[1,2,3,180,320],经过4x超分后[1,2,3,720,1280]
|
||||
# 所以导出模型相当重要
|
||||
out_width = 1280
|
||||
out_height = 720
|
||||
print(f"fps: {video_fps}\tframe_count: {video_frame_count}")
|
||||
# Create VideoWriter for output
|
||||
video_out_dir = "./"
|
||||
video_out_path = os.path.join(video_out_dir, video_out_name)
|
||||
fucc = cv2.VideoWriter_fourcc(*"mp4v")
|
||||
video_out = cv2.VideoWriter(video_out_path, fucc, video_fps,
|
||||
(out_width, out_height), True)
|
||||
if not video_out.isOpened():
|
||||
print("create video writer failed!")
|
||||
# Capture all frames and do inference
|
||||
frame_id = 0
|
||||
reach_end = False
|
||||
while capture.isOpened():
|
||||
imgs = []
|
||||
for i in range(args.frame_num):
|
||||
_, frame = capture.read()
|
||||
if frame is not None:
|
||||
imgs.append(frame)
|
||||
else:
|
||||
reach_end = True
|
||||
if reach_end:
|
||||
break
|
||||
results = model.predict(imgs)
|
||||
for item in results:
|
||||
# cv2.imshow("13", item)
|
||||
# cv2.waitKey(30)
|
||||
video_out.write(item)
|
||||
print("Processing frame: ", frame_id)
|
||||
frame_id += 1
|
||||
print("inference finished, output video saved at: ", video_out_path)
|
||||
capture.release()
|
||||
video_out.release()
|
Reference in New Issue
Block a user