mirror of https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 17:17:14 +08:00

[LLM] Support deploying LLM models

llm/client/fastdeploy_client/chatbot.py (new file, 308 lines)
@@ -0,0 +1,308 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
import queue
import traceback
import uuid
from functools import partial

import numpy as np
import tritonclient.grpc as grpcclient
from fastdeploy_client.message import ChatMessage
from fastdeploy_client.utils import is_enable_benchmark
from tritonclient import utils as triton_utils


class ChatBotClass(object):
    """
    Initiates conversations through the tritonclient interface of the model service.
    """

    def __init__(self, hostname, port, timeout=120):
        """
        Initialization function.

        Args:
            hostname (str): gRPC hostname
            port (int): gRPC port
            timeout (int): request timeout in seconds. Defaults to 120.

        Returns:
            None
        """
        self.url = f"{hostname}:{port}"
        self.timeout = timeout

    def stream_generate(self,
                        message,
                        max_dec_len=1024,
                        min_dec_len=1,
                        topp=0.7,
                        temperature=0.95,
                        frequency_score=0.0,
                        penalty_score=1.0,
                        presence_score=0.0,
                        system=None,
                        **kwargs):
        """
        Streaming interface.

        Args:
            message (Union[str, List[str], ChatMessage]): message content or a ChatMessage object
            max_dec_len (int, optional): maximum decoding length. Defaults to 1024.
            min_dec_len (int, optional): minimum decoding length. Defaults to 1.
            topp (float, optional): sampling randomness in the range 0~1; larger values produce more random output. Defaults to 0.7.
            temperature (float, optional): sampling temperature. Defaults to 0.95.
            frequency_score (float, optional): frequency penalty. Defaults to 0.0.
            penalty_score (float, optional): repetition penalty. Defaults to 1.0.
            presence_score (float, optional): presence penalty. Defaults to 0.0.
            system (str, optional): system prompt. Defaults to None.
            **kwargs: other parameters
                req_id (str, optional): request ID used to distinguish requests. Defaults to None.
                eos_token_ids (List[int], optional): token ids that terminate generation. Defaults to None.
                benchmark (bool, optional): benchmark mode; if enabled, the raw response is returned. Defaults to False.
                timeout (int, optional): request timeout in seconds; 120s is used if unset. Defaults to None.

        Returns:
            A generator that yields one dict per generated token.
            On success the dict looks like {"req_id": "xxx", "token": "OK", "is_end": 0},
            where token is the generated text and is_end marks the last token (0 for no, 1 for yes).
            On error the dict carries the error message instead, e.g. {"req_id": "xxx", "error_msg": "error message"}.
        """
        try:
            # prepare the request inputs
            model_name = "model"
            inputs = [grpcclient.InferInput("IN", [1], triton_utils.np_to_triton_dtype(np.object_))]
            outputs = [grpcclient.InferRequestedOutput("OUT")]
            output_data = OutputData()

            msg = message.message if isinstance(message, ChatMessage) else message
            # pass `system` through explicitly; it would otherwise be silently dropped
            input_data = self._prepare_input_data(msg, max_dec_len, min_dec_len,
                                                  topp, temperature, frequency_score,
                                                  penalty_score, presence_score,
                                                  system=system, **kwargs)
            req_id = input_data["req_id"]
            inputs[0].set_data_from_numpy(np.array([json.dumps([input_data])], dtype=np.object_))
            timeout = kwargs.get("timeout", self.timeout)

            with grpcclient.InferenceServerClient(url=self.url, verbose=False) as triton_client:
                # open the stream
                triton_client.start_stream(callback=partial(triton_callback, output_data))
                # send the request
                triton_client.async_stream_infer(model_name=model_name,
                                                 inputs=inputs,
                                                 request_id=req_id,
                                                 outputs=outputs)
                # consume the results
                answer_str = ""
                enable_benchmark = is_enable_benchmark(**kwargs)
                while True:
                    try:
                        response = output_data._completed_requests.get(timeout=timeout)
                    except queue.Empty:
                        yield {"req_id": req_id, "error_msg": f"Fetch response from server timeout ({timeout}s)"}
                        break
                    if isinstance(response, triton_utils.InferenceServerException):
                        yield {"req_id": req_id, "error_msg": f"InferenceServerException raised by inference: {response.message()}"}
                        break
                    else:
                        if enable_benchmark:
                            response = json.loads(response.as_numpy("OUT")[0])
                            if isinstance(response, (list, tuple)):
                                response = response[0]
                        else:
                            response = self._format_response(response, req_id)
                            token = response.get("token", "")
                            if isinstance(token, list):
                                token = token[0]
                            answer_str += token
                        yield response
                        if response.get("is_end") == 1 or response.get("error_msg") is not None:
                            break
                # close the stream explicitly
                triton_client.stop_stream(cancel_requests=True)
                triton_client.close()

            if isinstance(message, ChatMessage):
                message.message.append({"role": "assistant", "content": answer_str})
        except Exception as e:
            yield {"error_msg": f"{e}, details={str(traceback.format_exc())}"}

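    # A minimal streaming usage sketch (illustrative only; the hostname, port,
    # and prompt below are placeholders, not values shipped with this client):
    #
    #     chatbot = ChatBot(hostname="127.0.0.1", port=8000)
    #     for piece in chatbot.stream_generate("Hello, who are you?", topp=0.8):
    #         if "error_msg" in piece:
    #             print("error:", piece["error_msg"])
    #             break
    #         print(piece["token"], end="", flush=True)
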
    def generate(self,
                 message,
                 max_dec_len=1024,
                 min_dec_len=1,
                 topp=0.7,
                 temperature=0.95,
                 frequency_score=0.0,
                 penalty_score=1.0,
                 presence_score=0.0,
                 system=None,
                 **kwargs):
        """
        Returns the whole answer at once, built directly on the streaming interface.

        Args:
            message (Union[str, List[str], ChatMessage]): message content or a ChatMessage object
            max_dec_len (int, optional): maximum decoding length. Defaults to 1024.
            min_dec_len (int, optional): minimum decoding length. Defaults to 1.
            topp (float, optional): sampling randomness in the range 0~1; larger values produce more random output. Defaults to 0.7.
            temperature (float, optional): sampling temperature. Defaults to 0.95.
            frequency_score (float, optional): frequency penalty. Defaults to 0.0.
            penalty_score (float, optional): repetition penalty. Defaults to 1.0.
            presence_score (float, optional): presence penalty. Defaults to 0.0.
            system (str, optional): system prompt. Defaults to None.
            **kwargs: other parameters
                req_id (str, optional): request ID used to distinguish requests. Defaults to None.
                eos_token_ids (List[int], optional): token ids that terminate generation. Defaults to None.
                timeout (int, optional): request timeout in seconds; 120s is used if unset. Defaults to None.

        Returns:
            A dict.
            On success, e.g. {"req_id": "xxx", "results": "OK, got it."}
            On error, e.g. {"req_id": "xxx", "error_msg": "error message"}
        """
        stream_response = self.stream_generate(message, max_dec_len,
                                               min_dec_len, topp, temperature,
                                               frequency_score, penalty_score,
                                               presence_score, system, **kwargs)
        results = ""
        token_ids = list()
        error_msg = None
        for res in stream_response:
            if "token" not in res or "error_msg" in res:
                error_msg = f"response error, please check the info: {res}"
            elif isinstance(res["token"], list):
                results = res["token"]
                token_ids = res["token_ids"]
            else:
                results += res["token"]
                token_ids += res["token_ids"]
        if error_msg:
            # error responses may lack a req_id, so use .get() here
            return {"req_id": res.get("req_id"), "error_msg": error_msg}
        else:
            return {"req_id": res.get("req_id"), "results": results, "token_ids": token_ids}

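    # A minimal blocking usage sketch (hostname and port are placeholders):
    #
    #     chatbot = ChatBot(hostname="127.0.0.1", port=8000)
    #     result = chatbot.generate("Hello, who are you?", max_dec_len=256)
    #     if "error_msg" in result:
    #         print("error:", result["error_msg"])
    #     else:
    #         print(result["results"])
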
    def _prepare_input_data(self,
                            message,
                            max_dec_len=1024,
                            min_dec_len=2,
                            topp=0.0,
                            temperature=1.0,
                            frequency_score=0.0,
                            penalty_score=1.0,
                            presence_score=0.0,
                            system=None,
                            **kwargs):
        """
        Assembles the request payload sent to the server.
        """
        inputs = {
            "max_dec_len": max_dec_len,
            "min_dec_len": min_dec_len,
            "topp": topp,
            "temperature": temperature,
            "frequency_score": frequency_score,
            "penalty_score": penalty_score,
            "presence_score": presence_score,
        }

        if system is not None:
            inputs["system"] = system

        inputs["req_id"] = kwargs.get("req_id", str(uuid.uuid4()))
        if "eos_token_ids" in kwargs and kwargs["eos_token_ids"] is not None:
            inputs["eos_token_ids"] = kwargs["eos_token_ids"]
        inputs["response_timeout"] = kwargs.get("timeout", self.timeout)

        if isinstance(message, str):
            inputs["text"] = message
        elif isinstance(message, list):
            assert len(message) % 2 == 1, \
                "The length of message should be odd while it's a list."
            assert message[-1]["role"] == "user", \
                "The role of the {}-th element should be 'user'".format(len(message) - 1)
            for i in range(0, len(message) - 1, 2):
                assert message[i]["role"] == "user", \
                    "The role of the {}-th element should be 'user'".format(i)
                assert message[i + 1]["role"] == "assistant", \
                    "The role of the {}-th element should be 'assistant'".format(i + 1)
            inputs["messages"] = message
        else:
            raise Exception(
                "The message should be a string or a list of dicts like "
                "[{'role': 'user', 'content': \"Hello, what's your name?\"}]"
            )

        return inputs

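    # For reference, a multi-turn `message` list accepted above must alternate
    # user/assistant roles and end with a user turn, e.g. (contents made up):
    #
    #     message = [
    #         {"role": "user", "content": "What is FastDeploy?"},
    #         {"role": "assistant", "content": "A deployment toolkit from PaddlePaddle."},
    #         {"role": "user", "content": "How do I install it?"},
    #     ]
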
    def _format_response(self, response, req_id):
        """
        Normalizes the fields returned by the service.
        """
        response = json.loads(response.as_numpy("OUT")[0])
        if isinstance(response, (list, tuple)):
            response = response[0]
        is_end = response.get("is_end", False)

        if "error_msg" in response:
            return {"req_id": req_id, "error_msg": response["error_msg"]}
        elif "choices" in response:
            token = [x["token"] for x in response["choices"]]
            token_ids = [x["token_ids"] for x in response["choices"]]
            return {"req_id": req_id, "token": token, "token_ids": token_ids, "is_end": 1}
        elif "token" not in response and "result" not in response:
            return {"req_id": req_id, "error_msg": f"The response should contain 'token' or 'result', but got {response}"}
        else:
            token_ids = response.get("token_ids", [])
            if "result" in response:
                token = response["result"]
            elif "token" in response:
                token = response["token"]
            return {"req_id": req_id, "token": token, "token_ids": token_ids, "is_end": is_end}
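
    # For example (an assumed server payload, for illustration only): a raw chunk
    # such as {"token": "Hi", "token_ids": [123], "is_end": 0} is passed through as
    # {"req_id": req_id, "token": "Hi", "token_ids": [123], "is_end": 0}, while a
    # final packet carrying "choices" collapses all choices into parallel lists
    # and forces is_end to 1.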
class OutputData:
|
||||
"""接收Triton服务返回的数据"""
|
||||
def __init__(self):
|
||||
self._completed_requests = queue.Queue()
|
||||
|
||||
|
||||
def triton_callback(output_data, result, error):
|
||||
"""Triton客户端的回调函数"""
|
||||
if error:
|
||||
output_data._completed_requests.put(error)
|
||||
else:
|
||||
output_data._completed_requests.put(result)
|
||||
|
||||
|
||||
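
# The callback is bound to a per-request OutputData via functools.partial, so the
# gRPC worker thread hands results to the consuming generator through a
# thread-safe queue. Sketch of the wiring (mirrors stream_generate above):
#
#     output_data = OutputData()
#     client.start_stream(callback=partial(triton_callback, output_data))
#     ...
#     response = output_data._completed_requests.get(timeout=timeout)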


class ChatBot(object):
    """
    Public entry point; creates a ChatBotClass instance.
    """

    def __new__(cls, hostname, port, timeout=120):
        """
        Creates a gRPC inference service client.

        Args:
            hostname (str): server address
            port (int): server port
            timeout (int): request timeout in seconds. Defaults to 120.

        Returns:
            ChatBotClass: a ChatBotClass object
        """
        if not isinstance(hostname, str) or not hostname:
            raise ValueError("Invalid hostname")
        if not isinstance(port, int) or port <= 0 or port > 65535:
            raise ValueError("Invalid port number")

        return ChatBotClass(hostname, port, timeout)
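

if __name__ == "__main__":
    # End-to-end smoke test: a minimal sketch assuming a FastDeploy LLM service
    # is listening on localhost:8000 (both values are placeholders, not defaults
    # shipped with this client).
    bot = ChatBot(hostname="127.0.0.1", port=8000)
    print(bot.generate("Hello!", max_dec_len=64))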