[LLM] Support deploy LLM model

This commit is contained in:
chengyanfu
2024-08-29 19:38:11 +08:00
parent cd0ee79c91
commit 2d7d0ee92e
37 changed files with 4521 additions and 3 deletions

View File

@@ -0,0 +1,308 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import queue
import traceback
import uuid
from functools import partial
import numpy as np
import tritonclient.grpc as grpcclient
from fastdeploy_client.message import ChatMessage
from fastdeploy_client.utils import is_enable_benchmark
from tritonclient import utils as triton_utils
class ChatBotClass(object):
"""
initiating conversations through the tritonclient interface of the model service.
"""
def __init__(self, hostname, port, timeout=120):
"""
Initialization function
Args:
hostname (str): gRPC hostname
port (int): gRPC port
timeout (int): Request timeout, default is 120 seconds.
Returns:
None
"""
self.url = f"{hostname}:{port}"
self.timeout = timeout
def stream_generate(self,
message,
max_dec_len=1024,
min_dec_len=1,
topp=0.7,
temperature=0.95,
frequency_score=0.0,
penalty_score=1.0,
presence_score=0.0,
system=None,
**kwargs):
"""
Streaming interface
Args:
message (Union[str, List[str], ChatMessage]): 消息内容或ChatMessage对象
max_dec_len (int, optional): 最大解码长度. Defaults to 1024.
min_dec_len (int, optional): 最小解码长度. Defaults to 1.
topp (float, optional): 控制随机性参数数值越大则随机性越大范围是0~1. Defaults to 0.7.
temperature (float, optional): 温度值. Defaults to 0.95.
frequency_score (float, optional): 频率分数. Defaults to 0.0.
penalty_score (float, optional): 惩罚分数. Defaults to 1.0.
presence_score (float, optional): 存在分数. Defaults to 0.0.
system (str, optional): 系统设定. Defaults to None.
**kwargs: 其他参数
req_id (str, optional): 请求ID用于区分不同的请求. Defaults to None.
eos_token_ids (List[int], optional): 指定结束的token id. Defaults to None.
benchmark (bool, optional): 设置benchmark模式如果是则返回完整的response. Defaults to False.
timeout (int, optional): 请求超时时间不设置则使用120s. Defaults to None.
Returns:
返回一个生成器每次yield返回一个字典。
正常情况下,生成器返回字典的示例{"req_id": "xxx", "token": "好的", "is_end": 0}其中token为生成的字符is_end表明是否为最后一个字符0表示否1表示是
错误情况下,生成器返回错误信息的字典,示例 {"req_id": "xxx", "error_msg": "error message"}
"""
try:
# 准备输入
model_name = "model"
inputs = [grpcclient.InferInput("IN", [1], triton_utils.np_to_triton_dtype(np.object_))]
outputs = [grpcclient.InferRequestedOutput("OUT")]
output_data = OutputData()
msg = message.message if isinstance(message, ChatMessage) else message
input_data = self._prepare_input_data(msg, max_dec_len, min_dec_len,
topp, temperature, frequency_score,
penalty_score, presence_score, **kwargs)
req_id = input_data["req_id"]
inputs[0].set_data_from_numpy(np.array([json.dumps([input_data])], dtype=np.object_))
timeout = kwargs.get("timeout", self.timeout)
with grpcclient.InferenceServerClient(url=self.url, verbose=False) as triton_client:
# 建立连接
triton_client.start_stream(callback=partial(triton_callback, output_data))
# 发送请求
triton_client.async_stream_infer(model_name=model_name,
inputs=inputs,
request_id=req_id,
outputs=outputs)
# 处理结果
answer_str = ""
enable_benchmark = is_enable_benchmark(**kwargs)
while True:
try:
response = output_data._completed_requests.get(timeout=timeout)
except queue.Empty:
yield {"req_id": req_id, "error_msg": f"Fetch response from server timeout ({timeout}s)"}
break
if type(response) == triton_utils.InferenceServerException:
yield {"req_id": req_id, "error_msg": f"InferenceServerException raised by inference: {response.message()}"}
break
else:
if enable_benchmark:
response = json.loads(response.as_numpy("OUT")[0])
if isinstance(response, (list, tuple)):
response = response[0]
else:
response = self._format_response(response, req_id)
token = response.get("token", "")
if isinstance(token, list):
token = token[0]
answer_str += token
yield response
if response.get("is_end") == 1 or response.get("error_msg") is not None:
break
# 手动关闭
triton_client.stop_stream(cancel_requests=True)
triton_client.close()
if isinstance(message, ChatMessage):
message.message.append({"role": "assistant", "content": answer_str})
except Exception as e:
yield {"error_msg": f"{e}, details={str(traceback.format_exc())}"}
def generate(self,
message,
max_dec_len=1024,
min_dec_len=1,
topp=0.7,
temperature=0.95,
frequency_score=0.0,
penalty_score=1.0,
presence_score=0.0,
system=None,
**kwargs):
"""
整句返回,直接使用流式返回的接口。
Args:
message (Union[str, List[str], ChatMessage]): 消息内容或ChatMessage对象
max_dec_len (int, optional): 最大解码长度. Defaults to 1024.
min_dec_len (int, optional): 最小解码长度. Defaults to 1.
topp (float, optional): 控制随机性参数数值越大则随机性越大范围是0~1. Defaults to 0.7.
temperature (float, optional): 温度值. Defaults to 0.95.
frequency_score (float, optional): 频率分数. Defaults to 0.0.
penalty_score (float, optional): 惩罚分数. Defaults to 1.0.
presence_score (float, optional): 存在分数. Defaults to 0.0.
system (str, optional): 系统设定. Defaults to None.
**kwargs: 其他参数
req_id (str, optional): 请求ID用于区分不同的请求. Defaults to None.
eos_token_ids (List[int], optional): 指定结束的token id. Defaults to None.
timeout (int, optional): 请求超时时间不设置则使用120s. Defaults to None.
Returns:
返回一个字典。
正常情况下,返回字典的示例{"req_id": "xxx", "results": "好的,我知道了。"}
错误情况下,返回错误信息的字典,示例 {"req_id": "xxx", "error_msg": "error message"}
"""
stream_response = self.stream_generate(message, max_dec_len,
min_dec_len, topp, temperature,
frequency_score, penalty_score,
presence_score, system, **kwargs)
results = ""
token_ids = list()
error_msg = None
for res in stream_response:
if "token" not in res or "error_msg" in res:
error_msg = {"error_msg": f"response error, please check the info: {res}"}
elif isinstance(res["token"], list):
results = res["token"]
token_ids = res["token_ids"]
else:
results += res["token"]
token_ids += res["token_ids"]
if error_msg:
return {"req_id": res["req_id"], "error_msg": error_msg}
else:
return {"req_id": res["req_id"], "results": results, "token_ids": token_ids}
def _prepare_input_data(self,
message,
max_dec_len=1024,
min_dec_len=2,
topp=0.0,
temperature=1.0,
frequency_score=0.0,
penalty_score=1.0,
presence_score=0.0,
system=None,
**kwargs):
"""
准备输入数据。
"""
inputs = {
"max_dec_len": max_dec_len,
"min_dec_len": min_dec_len,
"topp": topp,
"temperature": temperature,
"frequency_score": frequency_score,
"penalty_score": penalty_score,
"presence_score": presence_score,
}
if system is not None:
inputs["system"] = system
inputs["req_id"] = kwargs.get("req_id", str(uuid.uuid4()))
if "eos_token_ids" in kwargs and kwargs["eos_token_ids"] is not None:
inputs["eos_token_ids"] = kwargs["eos_token_ids"]
inputs["response_timeout"] = kwargs.get("timeout", self.timeout)
if isinstance(message, str):
inputs["text"] = message
elif isinstance(message, list):
assert len(message) % 2 == 1, \
"The length of message should be odd while it's a list."
assert message[-1]["role"] == "user", \
"The {}-th element key should be 'user'".format(len(message) - 1)
for i in range(0, len(message) - 1, 2):
assert message[i]["role"] == "user", \
"The {}-th element key should be 'user'".format(i)
assert message[i + 1]["role"] == "assistant", \
"The {}-th element key should be 'assistant'".format(i + 1)
inputs["messages"] = message
else:
raise Exception(
"The message should be string or list of dict like [{'role': "
"'user', 'content': 'Hello, what's your name?''}]"
)
return inputs
def _format_response(self, response, req_id):
"""
对服务返回字段进行格式化
"""
response = json.loads(response.as_numpy("OUT")[0])
if isinstance(response, (list, tuple)):
response = response[0]
is_end = response.get("is_end", False)
if "error_msg" in response:
return {"req_id": req_id, "error_msg": response["error_msg"]}
elif "choices" in response:
token = [x["token"] for x in response["choices"]]
token_ids = [x["token_ids"] for x in response["choices"]]
return {"req_id": req_id, "token": token, "token_ids": token_ids, "is_end": 1}
elif "token" not in response and "result" not in response:
return {"req_id": req_id, "error_msg": f"The response should contain 'token' or 'result', but got {response}"}
else:
token_ids = response.get("token_ids", [])
if "result" in response:
token = response["result"]
elif "token" in response:
token = response["token"]
return {"req_id": req_id, "token": token, "token_ids": token_ids, "is_end": is_end}
class OutputData:
"""接收Triton服务返回的数据"""
def __init__(self):
self._completed_requests = queue.Queue()
def triton_callback(output_data, result, error):
"""Triton客户端的回调函数"""
if error:
output_data._completed_requests.put(error)
else:
output_data._completed_requests.put(result)
class ChatBot(object):
"""
对外的接口用于创建ChatBotForPushMode的示例
"""
def __new__(cls, hostname, port, timeout=120):
"""
初始化函数用于创建一个GRPCInferenceService客户端对象
Args:
hostname (str): 服务器的地址
port (int): 服务器的端口号
timeout (int): 请求超时时间单位为秒默认120秒
Returns:
ChatBotClass: 返回一个BaseChatBot对象
"""
if not isinstance(hostname, str) or not hostname:
raise ValueError("Invalid hostname")
if not isinstance(port, int) or port <= 0 or port > 65535:
raise ValueError("Invalid port number")
return ChatBotClass(hostname, port, timeout)