mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-11 11:30:20 +08:00
[Serving]support uie model (#599)
* serving support uie model * serving support uie model * delete comment
This commit is contained in:
@@ -37,3 +37,4 @@
|
|||||||
|
|
||||||
- [Python部署](python)
|
- [Python部署](python)
|
||||||
- [C++部署](cpp)
|
- [C++部署](cpp)
|
||||||
|
- [服务化部署](serving)
|
||||||
|
139
examples/text/uie/serving/README.md
Normal file
139
examples/text/uie/serving/README.md
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
# UIE 服务化部署示例
|
||||||
|
|
||||||
|
## 准备模型
|
||||||
|
|
||||||
|
下载UIE-Base模型(如果有已训练好的模型,跳过此步骤):
|
||||||
|
```bash
|
||||||
|
# 下载UIE模型文件和词表,以uie-base模型为例
|
||||||
|
wget https://bj.bcebos.com/fastdeploy/models/uie/uie-base.tgz
|
||||||
|
tar -xvfz uie-base.tgz
|
||||||
|
|
||||||
|
# 将下载的模型移动到模型仓库目录
|
||||||
|
mv uie-base/* models/uie/1/
|
||||||
|
```
|
||||||
|
|
||||||
|
模型下载移动好之后,目录结构如下:
|
||||||
|
```
|
||||||
|
models
|
||||||
|
└── uie
|
||||||
|
├── 1
|
||||||
|
│ ├── inference.pdiparams
|
||||||
|
│ ├── inference.pdmodel
|
||||||
|
│ ├── model.py
|
||||||
|
│ └── vocab.txt
|
||||||
|
└── config.pbtxt
|
||||||
|
```
|
||||||
|
|
||||||
|
## 拉取并运行镜像
|
||||||
|
```bash
|
||||||
|
# CPU镜像, 仅支持Paddle/ONNX模型在CPU上进行服务化部署,支持的推理后端包括OpenVINO、Paddle Inference和ONNX Runtime
|
||||||
|
docker pull paddlepaddle/fastdeploy:0.6.0-cpu-only-21.10
|
||||||
|
|
||||||
|
# GPU 镜像, 支持Paddle/ONNX模型在GPU/CPU上进行服务化部署,支持的推理后端包括OpenVINO、TensorRT、Paddle Inference和ONNX Runtime
|
||||||
|
docker pull paddlepaddle/fastdeploy:0.6.0-gpu-cuda11.4-trt8.4-21.10
|
||||||
|
|
||||||
|
# 运行容器.容器名字为 fd_serving, 并挂载当前目录为容器的 /uie_serving 目录
|
||||||
|
docker run -it --net=host --name fastdeploy_server --shm-size="1g" -v `pwd`/:/uie_serving paddlepaddle/fastdeploy:0.6.0-gpu-cuda11.4-trt8.4-21.10 bash
|
||||||
|
|
||||||
|
# 启动服务(不设置CUDA_VISIBLE_DEVICES环境变量,会拥有所有GPU卡的调度权限)
|
||||||
|
CUDA_VISIBLE_DEVICES=0 fastdeployserver --model-repository=/uie_serving/models --backend-config=python,shm-default-byte-size=10485760
|
||||||
|
```
|
||||||
|
|
||||||
|
>> **注意**: 当出现"Address already in use", 请使用`--grpc-port`指定端口号来启动服务,同时更改grpc_client.py中的请求端口号
|
||||||
|
|
||||||
|
服务启动成功后, 会有以下输出:
|
||||||
|
```
|
||||||
|
......
|
||||||
|
I0928 04:51:15.784517 206 grpc_server.cc:4117] Started GRPCInferenceService at 0.0.0.0:8001
|
||||||
|
I0928 04:51:15.785177 206 http_server.cc:2815] Started HTTPService at 0.0.0.0:8000
|
||||||
|
I0928 04:51:15.826578 206 http_server.cc:167] Started Metrics Service at 0.0.0.0:8002
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
## 客户端请求
|
||||||
|
客户端请求可以在本地执行脚本请求;也可以在容器中执行。
|
||||||
|
|
||||||
|
本地执行脚本需要先安装依赖:
|
||||||
|
```
|
||||||
|
pip install grpcio
|
||||||
|
pip install tritonclient[all]
|
||||||
|
|
||||||
|
# 如果bash无法识别括号,可以使用如下指令安装:
|
||||||
|
pip install tritonclient\[all\]
|
||||||
|
|
||||||
|
# 发送请求
|
||||||
|
python3 grpc_client.py
|
||||||
|
```
|
||||||
|
|
||||||
|
发送请求成功后,会返回结果并打印输出:
|
||||||
|
```
|
||||||
|
1. Named Entity Recognition Task--------------
|
||||||
|
The extraction schema: ['时间', '选手', '赛事名称']
|
||||||
|
text= ['2月8日上午北京冬奥会自由式滑雪女子大跳台决赛中中国选手谷爱凌以188.25分获得金牌!']
|
||||||
|
results:
|
||||||
|
{'时间': {'end': 6,
|
||||||
|
'probability': 0.9857379794120789,
|
||||||
|
'start': 0,
|
||||||
|
'text': '2月8日上午'},
|
||||||
|
'赛事名称': {'end': 23,
|
||||||
|
'probability': 0.8503087162971497,
|
||||||
|
'start': 6,
|
||||||
|
'text': '北京冬奥会自由式滑雪女子大跳台决赛'},
|
||||||
|
'选手': {'end': 31,
|
||||||
|
'probability': 0.8981545567512512,
|
||||||
|
'start': 28,
|
||||||
|
'text': '谷爱凌'}}
|
||||||
|
================================================
|
||||||
|
text= ['2月7日北京冬奥会短道速滑男子1000米决赛中任子威获得冠军!']
|
||||||
|
results:
|
||||||
|
{'时间': {'end': 4,
|
||||||
|
'probability': 0.9921242594718933,
|
||||||
|
'start': 0,
|
||||||
|
'text': '2月7日'},
|
||||||
|
'赛事名称': {'end': 22,
|
||||||
|
'probability': 0.8171929121017456,
|
||||||
|
'start': 4,
|
||||||
|
'text': '北京冬奥会短道速滑男子1000米决赛'},
|
||||||
|
'选手': {'end': 26,
|
||||||
|
'probability': 0.9821093678474426,
|
||||||
|
'start': 23,
|
||||||
|
'text': '任子威'}}
|
||||||
|
|
||||||
|
2. Relation Extraction Task
|
||||||
|
The extraction schema: {'竞赛名称': ['主办方', '承办方', '已举办次数']}
|
||||||
|
text= ['2022语言与智能技术竞赛由中国中文信息学会和中国计算机学会联合主办,百度公司、中国中文信息学会评测工作委员会和中国计算机学会自然语言处理专委会承办,已连续举办4届,成为全球最热门的中文NLP赛事之一。']
|
||||||
|
results:
|
||||||
|
{'竞赛名称': {'end': 13,
|
||||||
|
'probability': 0.7825395464897156,
|
||||||
|
'relation': {'主办方': [{'end': 22,
|
||||||
|
'probability': 0.8421710729598999,
|
||||||
|
'start': 14,
|
||||||
|
'text': '中国中文信息学会'},
|
||||||
|
{'end': 30,
|
||||||
|
'probability': 0.7580801248550415,
|
||||||
|
'start': 23,
|
||||||
|
'text': '中国计算机学会'}],
|
||||||
|
'已举办次数': [{'end': 82,
|
||||||
|
'probability': 0.4671308398246765,
|
||||||
|
'start': 80,
|
||||||
|
'text': '4届'}],
|
||||||
|
'承办方': [{'end': 39,
|
||||||
|
'probability': 0.8292703628540039,
|
||||||
|
'start': 35,
|
||||||
|
'text': '百度公司'},
|
||||||
|
{'end': 55,
|
||||||
|
'probability': 0.7000497579574585,
|
||||||
|
'start': 40,
|
||||||
|
'text': '中国中文信息学会评测工作委员会'},
|
||||||
|
{'end': 72,
|
||||||
|
'probability': 0.6193480491638184,
|
||||||
|
'start': 56,
|
||||||
|
'text': '中国计算机学会自然语言处理专委会'}]},
|
||||||
|
'start': 0,
|
||||||
|
'text': '2022语言与智能技术竞赛'}}
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
## 配置修改
|
||||||
|
|
||||||
|
当前默认配置在GPU上运行Paddle引擎,如果要在CPU/GPU或其他推理引擎上运行, 需要修改配置,详情请参考[配置文档](../../../../serving/docs/zh_CN/model_configuration.md)
|
151
examples/text/uie/serving/grpc_client.py
Executable file
151
examples/text/uie/serving/grpc_client.py
Executable file
@@ -0,0 +1,151 @@
|
|||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import numpy as np
|
||||||
|
from typing import Optional
|
||||||
|
import json
|
||||||
|
import ast
|
||||||
|
|
||||||
|
from pprint import pprint
|
||||||
|
from tritonclient import utils as client_utils
|
||||||
|
from tritonclient.grpc import InferenceServerClient, InferInput, InferRequestedOutput, service_pb2_grpc, service_pb2
|
||||||
|
|
||||||
|
LOGGER = logging.getLogger("run_inference_on_triton")
|
||||||
|
|
||||||
|
|
||||||
|
class SyncGRPCTritonRunner:
|
||||||
|
DEFAULT_MAX_RESP_WAIT_S = 120
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
server_url: str,
|
||||||
|
model_name: str,
|
||||||
|
model_version: str,
|
||||||
|
*,
|
||||||
|
verbose=False,
|
||||||
|
resp_wait_s: Optional[float]=None, ):
|
||||||
|
self._server_url = server_url
|
||||||
|
self._model_name = model_name
|
||||||
|
self._model_version = model_version
|
||||||
|
self._verbose = verbose
|
||||||
|
self._response_wait_t = self.DEFAULT_MAX_RESP_WAIT_S if resp_wait_s is None else resp_wait_s
|
||||||
|
|
||||||
|
self._client = InferenceServerClient(
|
||||||
|
self._server_url, verbose=self._verbose)
|
||||||
|
error = self._verify_triton_state(self._client)
|
||||||
|
if error:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Could not communicate to Triton Server: {error}")
|
||||||
|
|
||||||
|
LOGGER.debug(
|
||||||
|
f"Triton server {self._server_url} and model {self._model_name}:{self._model_version} "
|
||||||
|
f"are up and ready!")
|
||||||
|
|
||||||
|
model_config = self._client.get_model_config(self._model_name,
|
||||||
|
self._model_version)
|
||||||
|
model_metadata = self._client.get_model_metadata(self._model_name,
|
||||||
|
self._model_version)
|
||||||
|
LOGGER.info(f"Model config {model_config}")
|
||||||
|
LOGGER.info(f"Model metadata {model_metadata}")
|
||||||
|
|
||||||
|
self._inputs = {tm.name: tm for tm in model_metadata.inputs}
|
||||||
|
self._input_names = list(self._inputs)
|
||||||
|
self._outputs = {tm.name: tm for tm in model_metadata.outputs}
|
||||||
|
self._output_names = list(self._outputs)
|
||||||
|
self._outputs_req = [
|
||||||
|
InferRequestedOutput(name) for name in self._outputs
|
||||||
|
]
|
||||||
|
|
||||||
|
def Run(self, inputs):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
inputs: list, Each value corresponds to an input name of self._input_names
|
||||||
|
Returns:
|
||||||
|
results: dict, {name : numpy.array}
|
||||||
|
"""
|
||||||
|
infer_inputs = []
|
||||||
|
for idx, data in enumerate(inputs):
|
||||||
|
data = json.dumps(data)
|
||||||
|
data = np.array([[data], ], dtype=np.object_)
|
||||||
|
infer_input = InferInput(self._input_names[idx], data.shape,
|
||||||
|
"BYTES")
|
||||||
|
infer_input.set_data_from_numpy(data)
|
||||||
|
infer_inputs.append(infer_input)
|
||||||
|
|
||||||
|
results = self._client.infer(
|
||||||
|
model_name=self._model_name,
|
||||||
|
model_version=self._model_version,
|
||||||
|
inputs=infer_inputs,
|
||||||
|
outputs=self._outputs_req,
|
||||||
|
client_timeout=self._response_wait_t, )
|
||||||
|
# only one output
|
||||||
|
results = results.as_numpy(self._output_names[0])
|
||||||
|
return results
|
||||||
|
|
||||||
|
def _verify_triton_state(self, triton_client):
|
||||||
|
if not triton_client.is_server_live():
|
||||||
|
return f"Triton server {self._server_url} is not live"
|
||||||
|
elif not triton_client.is_server_ready():
|
||||||
|
return f"Triton server {self._server_url} is not ready"
|
||||||
|
elif not triton_client.is_model_ready(self._model_name,
|
||||||
|
self._model_version):
|
||||||
|
return f"Model {self._model_name}:{self._model_version} is not ready"
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
model_name = "uie"
|
||||||
|
model_version = "1"
|
||||||
|
url = "localhost:8001"
|
||||||
|
runner = SyncGRPCTritonRunner(url, model_name, model_version)
|
||||||
|
|
||||||
|
print("1. Named Entity Recognition Task--------------")
|
||||||
|
schema = ["时间", "选手", "赛事名称"]
|
||||||
|
print(f"The extraction schema: {schema}")
|
||||||
|
text = ["2月8日上午北京冬奥会自由式滑雪女子大跳台决赛中中国选手谷爱凌以188.25分获得金牌!"]
|
||||||
|
print("text=", text)
|
||||||
|
print("results:")
|
||||||
|
results = runner.Run([text, schema])
|
||||||
|
for result in results:
|
||||||
|
result = result.decode('utf-8')
|
||||||
|
result = ast.literal_eval(result)
|
||||||
|
pprint(result)
|
||||||
|
|
||||||
|
print("================================================")
|
||||||
|
text = ["2月7日北京冬奥会短道速滑男子1000米决赛中任子威获得冠军!"]
|
||||||
|
print("text=", text)
|
||||||
|
# while schema is empty, use the schema set up last time.
|
||||||
|
schema = []
|
||||||
|
results = runner.Run([text, schema])
|
||||||
|
print("results:")
|
||||||
|
for result in results:
|
||||||
|
result = result.decode('utf-8')
|
||||||
|
result = ast.literal_eval(result)
|
||||||
|
pprint(result)
|
||||||
|
|
||||||
|
print("\n2. Relation Extraction Task")
|
||||||
|
schema = {"竞赛名称": ["主办方", "承办方", "已举办次数"]}
|
||||||
|
print(f"The extraction schema: {schema}")
|
||||||
|
text = [
|
||||||
|
"2022语言与智能技术竞赛由中国中文信息学会和中国计算机学会联合主办,百度公司、中国中文信息学会评测工作"
|
||||||
|
"委员会和中国计算机学会自然语言处理专委会承办,已连续举办4届,成为全球最热门的中文NLP赛事之一。"
|
||||||
|
]
|
||||||
|
print("text=", text)
|
||||||
|
print("results:")
|
||||||
|
results = runner.Run([text, schema])
|
||||||
|
for result in results:
|
||||||
|
result = result.decode('utf-8')
|
||||||
|
result = ast.literal_eval(result)
|
||||||
|
pprint(result)
|
156
examples/text/uie/serving/models/uie/1/model.py
Normal file
156
examples/text/uie/serving/models/uie/1/model.py
Normal file
@@ -0,0 +1,156 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import json
|
||||||
|
import numpy as np
|
||||||
|
import os
|
||||||
|
|
||||||
|
import fastdeploy
|
||||||
|
from fastdeploy.text import UIEModel, SchemaLanguage
|
||||||
|
|
||||||
|
# triton_python_backend_utils is available in every Triton Python model. You
|
||||||
|
# need to use this module to create inference requests and responses. It also
|
||||||
|
# contains some utility functions for extracting information from model_config
|
||||||
|
# and converting Triton input/output types to numpy types.
|
||||||
|
import triton_python_backend_utils as pb_utils
|
||||||
|
|
||||||
|
|
||||||
|
class TritonPythonModel:
|
||||||
|
"""Your Python model must use the same class name. Every Python model
|
||||||
|
that is created must have "TritonPythonModel" as the class name.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def initialize(self, args):
|
||||||
|
"""`initialize` is called only once when the model is being loaded.
|
||||||
|
Implementing `initialize` function is optional. This function allows
|
||||||
|
the model to intialize any state associated with this model.
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
args : dict
|
||||||
|
Both keys and values are strings. The dictionary keys and values are:
|
||||||
|
* model_config: A JSON string containing the model configuration
|
||||||
|
* model_instance_kind: A string containing model instance kind
|
||||||
|
* model_instance_device_id: A string containing model instance device ID
|
||||||
|
* model_repository: Model repository path
|
||||||
|
* model_version: Model version
|
||||||
|
* model_name: Model name
|
||||||
|
"""
|
||||||
|
# You must parse model_config. JSON string is not parsed here
|
||||||
|
self.model_config = json.loads(args['model_config'])
|
||||||
|
print("model_config:", self.model_config)
|
||||||
|
|
||||||
|
self.input_names = []
|
||||||
|
for input_config in self.model_config["input"]:
|
||||||
|
self.input_names.append(input_config["name"])
|
||||||
|
print("input:", self.input_names)
|
||||||
|
|
||||||
|
self.output_names = []
|
||||||
|
self.output_dtype = []
|
||||||
|
for output_config in self.model_config["output"]:
|
||||||
|
self.output_names.append(output_config["name"])
|
||||||
|
dtype = pb_utils.triton_string_to_numpy(output_config["data_type"])
|
||||||
|
self.output_dtype.append(dtype)
|
||||||
|
print("output:", self.output_names)
|
||||||
|
|
||||||
|
#Init fastdeploy.RuntimeOption
|
||||||
|
runtime_option = fastdeploy.RuntimeOption()
|
||||||
|
options = None
|
||||||
|
if (args['model_instance_kind'] == 'GPU'):
|
||||||
|
runtime_option.use_gpu(int(args['model_instance_device_id']))
|
||||||
|
options = self.model_config['optimization'][
|
||||||
|
'execution_accelerators']['gpu_execution_accelerator']
|
||||||
|
else:
|
||||||
|
runtime_option.use_cpu()
|
||||||
|
options = self.model_config['optimization'][
|
||||||
|
'execution_accelerators']['cpu_execution_accelerator']
|
||||||
|
|
||||||
|
for option in options:
|
||||||
|
if option['name'] == 'paddle':
|
||||||
|
runtime_option.use_paddle_backend()
|
||||||
|
elif option['name'] == 'onnxruntime':
|
||||||
|
runtime_option.use_ort_backend()
|
||||||
|
elif option['name'] == 'openvino':
|
||||||
|
runtime_option.use_openvino_backend()
|
||||||
|
|
||||||
|
if option['parameters']:
|
||||||
|
if 'cpu_threads' in option['parameters']:
|
||||||
|
runtime_option.set_cpu_thread_num(
|
||||||
|
int(option['parameters']['cpu_threads']))
|
||||||
|
|
||||||
|
model_path = os.path.abspath(os.path.dirname(
|
||||||
|
__file__)) + "/inference.pdmodel"
|
||||||
|
param_path = os.path.abspath(os.path.dirname(
|
||||||
|
__file__)) + "/inference.pdiparams"
|
||||||
|
vocab_path = os.path.abspath(os.path.dirname(__file__)) + "/vocab.txt"
|
||||||
|
schema = []
|
||||||
|
# init UIE model
|
||||||
|
self.uie_model_ = UIEModel(
|
||||||
|
model_path,
|
||||||
|
param_path,
|
||||||
|
vocab_path,
|
||||||
|
position_prob=0.5,
|
||||||
|
max_length=128,
|
||||||
|
schema=schema,
|
||||||
|
runtime_option=runtime_option,
|
||||||
|
schema_language=SchemaLanguage.ZH)
|
||||||
|
|
||||||
|
def execute(self, requests):
|
||||||
|
"""`execute` must be implemented in every Python model. `execute`
|
||||||
|
function receives a list of pb_utils.InferenceRequest as the only
|
||||||
|
argument. This function is called when an inference is requested
|
||||||
|
for this model. Depending on the batching configuration (e.g. Dynamic
|
||||||
|
Batching) used, `requests` may contain multiple requests. Every
|
||||||
|
Python model, must create one pb_utils.InferenceResponse for every
|
||||||
|
pb_utils.InferenceRequest in `requests`. If there is an error, you can
|
||||||
|
set the error argument when creating a pb_utils.InferenceResponse.
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
requests : list
|
||||||
|
A list of pb_utils.InferenceRequest
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
list
|
||||||
|
A list of pb_utils.InferenceResponse. The length of this list must
|
||||||
|
be the same as `requests`
|
||||||
|
"""
|
||||||
|
responses = []
|
||||||
|
for request in requests:
|
||||||
|
texts = pb_utils.get_input_tensor_by_name(request,
|
||||||
|
self.input_names[0])
|
||||||
|
schema = pb_utils.get_input_tensor_by_name(request,
|
||||||
|
self.input_names[1])
|
||||||
|
texts = texts.as_numpy()
|
||||||
|
schema = schema.as_numpy()
|
||||||
|
# not support batch predict
|
||||||
|
texts = json.loads(texts[0][0])
|
||||||
|
schema = json.loads(schema[0][0])
|
||||||
|
|
||||||
|
if schema:
|
||||||
|
self.uie_model_.set_schema(schema)
|
||||||
|
results = self.uie_model_.predict(texts, return_dict=True)
|
||||||
|
|
||||||
|
results = np.array(results, dtype=np.object)
|
||||||
|
out_tensor = pb_utils.Tensor(self.output_names[0], results)
|
||||||
|
inference_response = pb_utils.InferenceResponse(
|
||||||
|
output_tensors=[out_tensor, ])
|
||||||
|
responses.append(inference_response)
|
||||||
|
return responses
|
||||||
|
|
||||||
|
def finalize(self):
|
||||||
|
"""`finalize` is called only once when the model is being unloaded.
|
||||||
|
Implementing `finalize` function is optional. This function allows
|
||||||
|
the model to perform any necessary clean ups before exit.
|
||||||
|
"""
|
||||||
|
print('Cleaning up...')
|
46
examples/text/uie/serving/models/uie/config.pbtxt
Normal file
46
examples/text/uie/serving/models/uie/config.pbtxt
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
name: "uie"
|
||||||
|
backend: "python"
|
||||||
|
max_batch_size: 1
|
||||||
|
|
||||||
|
input [
|
||||||
|
{
|
||||||
|
name: "INPUT_0"
|
||||||
|
data_type: TYPE_STRING
|
||||||
|
dims: [ 1 ]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "INPUT_1"
|
||||||
|
data_type: TYPE_STRING
|
||||||
|
dims: [ 1 ]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
output [
|
||||||
|
{
|
||||||
|
name: "OUTPUT_0"
|
||||||
|
data_type: TYPE_STRING
|
||||||
|
dims: [ 1 ]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
instance_group [
|
||||||
|
{
|
||||||
|
count: 1
|
||||||
|
# Use GPU, CPU inference option is:KIND_CPU
|
||||||
|
kind: KIND_GPU
|
||||||
|
# The instance is deployed on the 0th GPU card
|
||||||
|
gpus: [0]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
optimization {
|
||||||
|
execution_accelerators {
|
||||||
|
gpu_execution_accelerator : [
|
||||||
|
{
|
||||||
|
# use paddle backend
|
||||||
|
name: "paddle"
|
||||||
|
parameters { key: "cpu_threads" value: "12" }
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
Reference in New Issue
Block a user