From 2ac94e91beca5ed974c9a41189dffb619267a7ba Mon Sep 17 00:00:00 2001 From: zhoushunjie Date: Fri, 4 Nov 2022 03:47:47 +0000 Subject: [PATCH] fix trt dy shape --- benchmark/benchmark_uie.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/benchmark/benchmark_uie.py b/benchmark/benchmark_uie.py index d7f74a048..a97bb026f 100644 --- a/benchmark/benchmark_uie.py +++ b/benchmark/benchmark_uie.py @@ -84,23 +84,23 @@ def build_option(args): trt_file = os.path.join(args.model_dir, "infer.trt") option.set_trt_input_shape( 'input_ids', - min_shape=[1, args.max_length], - opt_shape=[args.batch_size, args.max_length], + min_shape=[1, 1], + opt_shape=[args.batch_size, args.max_length // 2], max_shape=[args.batch_size, args.max_length]) option.set_trt_input_shape( 'token_type_ids', - min_shape=[1, args.max_length], - opt_shape=[args.batch_size, args.max_length], + min_shape=[1, 1], + opt_shape=[args.batch_size, args.max_length // 2], max_shape=[args.batch_size, args.max_length]) option.set_trt_input_shape( 'pos_ids', - min_shape=[1, args.max_length], - opt_shape=[args.batch_size, args.max_length], + min_shape=[1, 1], + opt_shape=[args.batch_size, args.max_length // 2], max_shape=[args.batch_size, args.max_length]) option.set_trt_input_shape( 'att_mask', - min_shape=[1, args.max_length], - opt_shape=[args.batch_size, args.max_length], + min_shape=[1, 1], + opt_shape=[args.batch_size, args.max_length // 2], max_shape=[args.batch_size, args.max_length]) if args.use_fp16: option.enable_trt_fp16()