mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 09:07:10 +08:00
fix trt dy shape
This commit is contained in:
@@ -84,23 +84,23 @@ def build_option(args):
|
|||||||
trt_file = os.path.join(args.model_dir, "infer.trt")
|
trt_file = os.path.join(args.model_dir, "infer.trt")
|
||||||
option.set_trt_input_shape(
|
option.set_trt_input_shape(
|
||||||
'input_ids',
|
'input_ids',
|
||||||
min_shape=[1, args.max_length],
|
min_shape=[1, 1],
|
||||||
opt_shape=[args.batch_size, args.max_length],
|
opt_shape=[args.batch_size, args.max_length // 2],
|
||||||
max_shape=[args.batch_size, args.max_length])
|
max_shape=[args.batch_size, args.max_length])
|
||||||
option.set_trt_input_shape(
|
option.set_trt_input_shape(
|
||||||
'token_type_ids',
|
'token_type_ids',
|
||||||
min_shape=[1, args.max_length],
|
min_shape=[1, 1],
|
||||||
opt_shape=[args.batch_size, args.max_length],
|
opt_shape=[args.batch_size, args.max_length // 2],
|
||||||
max_shape=[args.batch_size, args.max_length])
|
max_shape=[args.batch_size, args.max_length])
|
||||||
option.set_trt_input_shape(
|
option.set_trt_input_shape(
|
||||||
'pos_ids',
|
'pos_ids',
|
||||||
min_shape=[1, args.max_length],
|
min_shape=[1, 1],
|
||||||
opt_shape=[args.batch_size, args.max_length],
|
opt_shape=[args.batch_size, args.max_length // 2],
|
||||||
max_shape=[args.batch_size, args.max_length])
|
max_shape=[args.batch_size, args.max_length])
|
||||||
option.set_trt_input_shape(
|
option.set_trt_input_shape(
|
||||||
'att_mask',
|
'att_mask',
|
||||||
min_shape=[1, args.max_length],
|
min_shape=[1, 1],
|
||||||
opt_shape=[args.batch_size, args.max_length],
|
opt_shape=[args.batch_size, args.max_length // 2],
|
||||||
max_shape=[args.batch_size, args.max_length])
|
max_shape=[args.batch_size, args.max_length])
|
||||||
if args.use_fp16:
|
if args.use_fp16:
|
||||||
option.enable_trt_fp16()
|
option.enable_trt_fp16()
|
||||||
|
Reference in New Issue
Block a user