[Model] add vsr serials models (#518)

* [Model] add vsr serials models

Signed-off-by: ChaoII <849453582@qq.com>

* [Model] add vsr serials models

Signed-off-by: ChaoII <849453582@qq.com>

* fix build problem

Signed-off-by: ChaoII <849453582@qq.com>

* fix code style

Signed-off-by: ChaoII <849453582@qq.com>

* modify according to review suggestions

Signed-off-by: ChaoII <849453582@qq.com>

* modify vsr trt example

Signed-off-by: ChaoII <849453582@qq.com>

* update sr directory

* fix BindPPSR

* add doxygen comment

* add sr unit test

* update model file url

Signed-off-by: ChaoII <849453582@qq.com>
Co-authored-by: Jason <jiangjiajun@baidu.com>
This commit is contained in:
ChaoII
2022-11-21 10:58:28 +08:00
committed by GitHub
parent 1ac54c96bd
commit c7ec14de95
40 changed files with 2526 additions and 8 deletions

View File

@@ -0,0 +1,86 @@
import cv2
import os
import fastdeploy as fd
def parse_arguments():
import argparse
import ast
parser = argparse.ArgumentParser()
parser.add_argument("--model", required=True, help="Path of model.")
parser.add_argument(
"--video", type=str, required=True, help="Path of test video file.")
parser.add_argument("--frame_num", type=int, default=2, help="frame num")
parser.add_argument(
"--device",
type=str,
default='cpu',
help="Type of inference device, support 'cpu' or 'gpu'.")
parser.add_argument(
"--use_trt",
type=ast.literal_eval,
default=False,
help="Wether to use tensorrt.")
return parser.parse_args()
def build_option(args):
option = fd.RuntimeOption()
if args.device.lower() == "gpu":
option.use_gpu()
if args.use_trt:
option.use_trt_backend()
option.enable_paddle_to_trt()
return option
args = parse_arguments()
# 配置runtime加载模型
runtime_option = build_option(args)
model_file = os.path.join(args.model, "model.pdmodel")
params_file = os.path.join(args.model, "model.pdiparams")
model = fd.vision.sr.PPMSVSR(
model_file, params_file, runtime_option=runtime_option)
# 该处应该与你导出模型的第二个维度一致模型输入shape=[b,n,c,h,w]
capture = cv2.VideoCapture(args.video)
video_out_name = "output.mp4"
video_fps = capture.get(cv2.CAP_PROP_FPS)
video_frame_count = capture.get(cv2.CAP_PROP_FRAME_COUNT)
# 注意导出模型时尺寸与原始输入的分辨一致比如:[1,2,3,180,320],经过4x超分后[1,2,3,720,1280]
# 所以导出模型相当重要
out_width = 1280
out_height = 720
print(f"fps: {video_fps}\tframe_count: {video_frame_count}")
# Create VideoWriter for output
video_out_dir = "./"
video_out_path = os.path.join(video_out_dir, video_out_name)
fucc = cv2.VideoWriter_fourcc(*"mp4v")
video_out = cv2.VideoWriter(video_out_path, fucc, video_fps,
(out_width, out_height), True)
if not video_out.isOpened():
print("create video writer failed!")
# Capture all frames and do inference
frame_id = 0
reach_end = False
while capture.isOpened():
imgs = []
for i in range(args.frame_num):
_, frame = capture.read()
if frame is not None:
imgs.append(frame)
else:
reach_end = True
if reach_end:
break
results = model.predict(imgs)
for item in results:
# cv2.imshow("13", item)
# cv2.waitKey(30)
video_out.write(item)
print("Processing frame: ", frame_id)
frame_id += 1
print("inference finished, output video saved at: ", video_out_path)
capture.release()
video_out.release()