[Model] Add trt usage for uie (#967)

Add trt
This commit is contained in:
Jack Zhou
2022-12-26 16:38:10 +08:00
committed by GitHub
parent 1911002b90
commit df940b750f
3 changed files with 81 additions and 7 deletions

View File

@@ -15,6 +15,7 @@ import fastdeploy
from fastdeploy.text import UIEModel, SchemaLanguage
import os
from pprint import pprint
import distutils.util
def parse_arguments():
@@ -31,17 +32,34 @@ def parse_arguments():
default='cpu',
choices=['cpu', 'gpu'],
help="Type of inference device, support 'cpu' or 'gpu'.")
parser.add_argument(
"--batch_size", type=int, default=1, help="The batch size of data.")
parser.add_argument(
"--device_id", type=int, default=0, help="device(gpu) id")
parser.add_argument(
"--max_length",
type=int,
default=128,
help="The max length of sequence.")
parser.add_argument(
"--backend",
type=str,
default='onnx_runtime',
choices=['onnx_runtime', 'paddle_inference', 'openvino'],
default='paddle_inference',
choices=[
'onnx_runtime', 'paddle_inference', 'openvino', 'paddle_tensorrt',
'tensorrt'
],
help="The inference runtime backend.")
parser.add_argument(
"--cpu_num_threads",
type=int,
default=8,
help="The number of threads to execute inference in cpu device.")
parser.add_argument(
"--use_fp16",
type=distutils.util.strtobool,
default=False,
help="Use FP16 mode")
return parser.parse_args()
@@ -50,8 +68,9 @@ def build_option(args):
# Set device
if args.device == 'cpu':
runtime_option.use_cpu()
runtime_option.set_cpu_thread_num(args.cpu_num_threads)
else:
runtime_option.use_gpu()
runtime_option.use_gpu(args.device_id)
# Set backend
if args.backend == 'onnx_runtime':
@@ -60,7 +79,37 @@ def build_option(args):
runtime_option.use_paddle_infer_backend()
elif args.backend == 'openvino':
runtime_option.use_openvino_backend()
runtime_option.set_cpu_thread_num(args.cpu_num_threads)
else:
runtime_option.use_trt_backend()
if args.backend == 'paddle_tensorrt':
runtime_option.enable_paddle_to_trt()
runtime_option.enable_paddle_trt_collect_shape()
# Only useful for single stage predict
runtime_option.set_trt_input_shape(
'input_ids',
min_shape=[1, 1],
opt_shape=[args.batch_size, args.max_length // 2],
max_shape=[args.batch_size, args.max_length])
runtime_option.set_trt_input_shape(
'token_type_ids',
min_shape=[1, 1],
opt_shape=[args.batch_size, args.max_length // 2],
max_shape=[args.batch_size, args.max_length])
runtime_option.set_trt_input_shape(
'pos_ids',
min_shape=[1, 1],
opt_shape=[args.batch_size, args.max_length // 2],
max_shape=[args.batch_size, args.max_length])
runtime_option.set_trt_input_shape(
'att_mask',
min_shape=[1, 1],
opt_shape=[args.batch_size, args.max_length // 2],
max_shape=[args.batch_size, args.max_length])
trt_file = os.path.join(args.model_dir, "inference.trt")
if args.use_fp16:
runtime_option.enable_trt_fp16()
trt_file = trt_file + ".fp16"
runtime_option.set_trt_cache_file(trt_file)
return runtime_option
@@ -78,7 +127,7 @@ if __name__ == "__main__":
param_path,
vocab_path,
position_prob=0.5,
max_length=128,
max_length=args.max_length,
schema=schema,
runtime_option=runtime_option,
schema_language=SchemaLanguage.ZH)
@@ -132,8 +181,7 @@ if __name__ == "__main__":
schema = {"评价维度": ["观点词", "情感倾向[正向,负向]"]}
print(f"The extraction schema: {schema}")
uie.set_schema(schema)
results = uie.predict(
["店面干净,很清静,服务员服务热情,性价比很高,发现收银台有排队"], return_dict=True)
results = uie.predict(["店面干净,很清静"], return_dict=True)
pprint(results)
print()