mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-07 01:22:59 +08:00
@@ -15,6 +15,7 @@ import fastdeploy
|
||||
from fastdeploy.text import UIEModel, SchemaLanguage
|
||||
import os
|
||||
from pprint import pprint
|
||||
import distutils.util
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
@@ -31,17 +32,34 @@ def parse_arguments():
|
||||
default='cpu',
|
||||
choices=['cpu', 'gpu'],
|
||||
help="Type of inference device, support 'cpu' or 'gpu'.")
|
||||
parser.add_argument(
|
||||
"--batch_size", type=int, default=1, help="The batch size of data.")
|
||||
parser.add_argument(
|
||||
"--device_id", type=int, default=0, help="device(gpu) id")
|
||||
parser.add_argument(
|
||||
"--max_length",
|
||||
type=int,
|
||||
default=128,
|
||||
help="The max length of sequence.")
|
||||
parser.add_argument(
|
||||
"--backend",
|
||||
type=str,
|
||||
default='onnx_runtime',
|
||||
choices=['onnx_runtime', 'paddle_inference', 'openvino'],
|
||||
default='paddle_inference',
|
||||
choices=[
|
||||
'onnx_runtime', 'paddle_inference', 'openvino', 'paddle_tensorrt',
|
||||
'tensorrt'
|
||||
],
|
||||
help="The inference runtime backend.")
|
||||
parser.add_argument(
|
||||
"--cpu_num_threads",
|
||||
type=int,
|
||||
default=8,
|
||||
help="The number of threads to execute inference in cpu device.")
|
||||
parser.add_argument(
|
||||
"--use_fp16",
|
||||
type=distutils.util.strtobool,
|
||||
default=False,
|
||||
help="Use FP16 mode")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@@ -50,8 +68,9 @@ def build_option(args):
|
||||
# Set device
|
||||
if args.device == 'cpu':
|
||||
runtime_option.use_cpu()
|
||||
runtime_option.set_cpu_thread_num(args.cpu_num_threads)
|
||||
else:
|
||||
runtime_option.use_gpu()
|
||||
runtime_option.use_gpu(args.device_id)
|
||||
|
||||
# Set backend
|
||||
if args.backend == 'onnx_runtime':
|
||||
@@ -60,7 +79,37 @@ def build_option(args):
|
||||
runtime_option.use_paddle_infer_backend()
|
||||
elif args.backend == 'openvino':
|
||||
runtime_option.use_openvino_backend()
|
||||
runtime_option.set_cpu_thread_num(args.cpu_num_threads)
|
||||
else:
|
||||
runtime_option.use_trt_backend()
|
||||
if args.backend == 'paddle_tensorrt':
|
||||
runtime_option.enable_paddle_to_trt()
|
||||
runtime_option.enable_paddle_trt_collect_shape()
|
||||
# Only useful for single stage predict
|
||||
runtime_option.set_trt_input_shape(
|
||||
'input_ids',
|
||||
min_shape=[1, 1],
|
||||
opt_shape=[args.batch_size, args.max_length // 2],
|
||||
max_shape=[args.batch_size, args.max_length])
|
||||
runtime_option.set_trt_input_shape(
|
||||
'token_type_ids',
|
||||
min_shape=[1, 1],
|
||||
opt_shape=[args.batch_size, args.max_length // 2],
|
||||
max_shape=[args.batch_size, args.max_length])
|
||||
runtime_option.set_trt_input_shape(
|
||||
'pos_ids',
|
||||
min_shape=[1, 1],
|
||||
opt_shape=[args.batch_size, args.max_length // 2],
|
||||
max_shape=[args.batch_size, args.max_length])
|
||||
runtime_option.set_trt_input_shape(
|
||||
'att_mask',
|
||||
min_shape=[1, 1],
|
||||
opt_shape=[args.batch_size, args.max_length // 2],
|
||||
max_shape=[args.batch_size, args.max_length])
|
||||
trt_file = os.path.join(args.model_dir, "inference.trt")
|
||||
if args.use_fp16:
|
||||
runtime_option.enable_trt_fp16()
|
||||
trt_file = trt_file + ".fp16"
|
||||
runtime_option.set_trt_cache_file(trt_file)
|
||||
return runtime_option
|
||||
|
||||
|
||||
@@ -78,7 +127,7 @@ if __name__ == "__main__":
|
||||
param_path,
|
||||
vocab_path,
|
||||
position_prob=0.5,
|
||||
max_length=128,
|
||||
max_length=args.max_length,
|
||||
schema=schema,
|
||||
runtime_option=runtime_option,
|
||||
schema_language=SchemaLanguage.ZH)
|
||||
@@ -132,8 +181,7 @@ if __name__ == "__main__":
|
||||
schema = {"评价维度": ["观点词", "情感倾向[正向,负向]"]}
|
||||
print(f"The extraction schema: {schema}")
|
||||
uie.set_schema(schema)
|
||||
results = uie.predict(
|
||||
["店面干净,很清静,服务员服务热情,性价比很高,发现收银台有排队"], return_dict=True)
|
||||
results = uie.predict(["店面干净,很清静"], return_dict=True)
|
||||
pprint(results)
|
||||
print()
|
||||
|
||||
|
Reference in New Issue
Block a user