fix trt dy shape

This commit is contained in:
zhoushunjie
2022-11-04 03:47:47 +00:00
parent 3017ec487c
commit 2ac94e91be

View File

@@ -84,23 +84,23 @@ def build_option(args):
trt_file = os.path.join(args.model_dir, "infer.trt")
option.set_trt_input_shape(
'input_ids',
min_shape=[1, args.max_length],
opt_shape=[args.batch_size, args.max_length],
min_shape=[1, 1],
opt_shape=[args.batch_size, args.max_length // 2],
max_shape=[args.batch_size, args.max_length])
option.set_trt_input_shape(
'token_type_ids',
min_shape=[1, args.max_length],
opt_shape=[args.batch_size, args.max_length],
min_shape=[1, 1],
opt_shape=[args.batch_size, args.max_length // 2],
max_shape=[args.batch_size, args.max_length])
option.set_trt_input_shape(
'pos_ids',
min_shape=[1, args.max_length],
opt_shape=[args.batch_size, args.max_length],
min_shape=[1, 1],
opt_shape=[args.batch_size, args.max_length // 2],
max_shape=[args.batch_size, args.max_length])
option.set_trt_input_shape(
'att_mask',
min_shape=[1, args.max_length],
opt_shape=[args.batch_size, args.max_length],
min_shape=[1, 1],
opt_shape=[args.batch_size, args.max_length // 2],
max_shape=[args.batch_size, args.max_length])
if args.use_fp16:
option.enable_trt_fp16()