mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 17:17:14 +08:00
update code comments
This commit is contained in:
@@ -60,28 +60,25 @@ class ChatBotClass(object):
|
||||
Streaming interface
|
||||
|
||||
Args:
|
||||
message (Union[str, List[str], ChatMessage]): 消息内容或ChatMessage对象
|
||||
max_dec_len (int, optional): 最大解码长度. Defaults to 1024.
|
||||
min_dec_len (int, optional): 最小解码长度. Defaults to 1.
|
||||
topp (float, optional): 控制随机性参数,数值越大则随机性越大,范围是0~1. Defaults to 0.7.
|
||||
temperature (float, optional): 温度值. Defaults to 0.95.
|
||||
frequency_score (float, optional): 频率分数. Defaults to 0.0.
|
||||
penalty_score (float, optional): 惩罚分数. Defaults to 1.0.
|
||||
presence_score (float, optional): 存在分数. Defaults to 0.0.
|
||||
system (str, optional): 系统设定. Defaults to None.
|
||||
**kwargs: 其他参数
|
||||
req_id (str, optional): 请求ID,用于区分不同的请求. Defaults to None.
|
||||
eos_token_ids (List[int], optional): 指定结束的token id. Defaults to None.
|
||||
benchmark (bool, optional): 设置benchmark模式,如果是则返回完整的response. Defaults to False.
|
||||
timeout (int, optional): 请求超时时间,不设置则使用120s. Defaults to None.
|
||||
message (Union[str, List[str], ChatMessage]): message or ChatMessage object
|
||||
max_dec_len (int, optional): max decoding length. Defaults to 1024.
|
||||
min_dec_len (int, optional): min decoding length. Defaults to 1.
|
||||
topp (float, optional): randomness of the generated tokens. Defaults to 0.7.
|
||||
temperature (float, optional): temperature. Defaults to 0.95.
|
||||
frequency_score (float, optional): frequency score. Defaults to 0.0.
|
||||
penalty_score (float, optional): penalty score. Defaults to 1.0.
|
||||
presence_score (float, optional): presence score. Defaults to 0.0.
|
||||
system (str, optional): system settings. Defaults to None.
|
||||
**kwargs: others
|
||||
|
||||
For more details, please refer to https://github.com/PaddlePaddle/FastDeploy/blob/develop/llm/docs/FastDeploy_usage_tutorial.md#%E8%AF%B7%E6%B1%82%E5%8F%82%E6%95%B0%E4%BB%8B%E7%BB%8D
|
||||
|
||||
Returns:
|
||||
返回一个生成器,每次yield返回一个字典。
|
||||
正常情况下,生成器返回字典的示例{"req_id": "xxx", "token": "好的", "is_end": 0},其中token为生成的字符,is_end表明是否为最后一个字符(0表示否,1表示是)
|
||||
错误情况下,生成器返回错误信息的字典,示例 {"req_id": "xxx", "error_msg": "error message"}
|
||||
return a generator object, which yields a dict.
|
||||
Normal, return {'token': xxx, 'is_end': xxx, 'send_idx': xxx, ..., 'error_msg': '', 'error_code': 0}
|
||||
Others, return {'error_msg': xxx, 'error_code': xxx}, error_msg not None, error_code != 0
|
||||
"""
|
||||
try:
|
||||
# 准备输入
|
||||
model_name = "model"
|
||||
inputs = [grpcclient.InferInput("IN", [1], triton_utils.np_to_triton_dtype(np.object_))]
|
||||
outputs = [grpcclient.InferRequestedOutput("OUT")]
|
||||
@@ -96,14 +93,11 @@ class ChatBotClass(object):
|
||||
timeout = kwargs.get("timeout", self.timeout)
|
||||
|
||||
with grpcclient.InferenceServerClient(url=self.url, verbose=False) as triton_client:
|
||||
# 建立连接
|
||||
triton_client.start_stream(callback=partial(triton_callback, output_data))
|
||||
# 发送请求
|
||||
triton_client.async_stream_infer(model_name=model_name,
|
||||
inputs=inputs,
|
||||
request_id=req_id,
|
||||
outputs=outputs)
|
||||
# 处理结果
|
||||
answer_str = ""
|
||||
enable_benchmark = is_enable_benchmark(**kwargs)
|
||||
while True:
|
||||
@@ -129,7 +123,6 @@ class ChatBotClass(object):
|
||||
yield response
|
||||
if response.get("is_end") == 1 or response.get("error_msg") is not None:
|
||||
break
|
||||
# 手动关闭
|
||||
triton_client.stop_stream(cancel_requests=True)
|
||||
triton_client.close()
|
||||
|
||||
@@ -150,27 +143,26 @@ class ChatBotClass(object):
|
||||
system=None,
|
||||
**kwargs):
|
||||
"""
|
||||
整句返回,直接使用流式返回的接口。
|
||||
Return the entire sentence using the streaming interface.
|
||||
|
||||
Args:
|
||||
message (Union[str, List[str], ChatMessage]): 消息内容或ChatMessage对象
|
||||
max_dec_len (int, optional): 最大解码长度. Defaults to 1024.
|
||||
min_dec_len (int, optional): 最小解码长度. Defaults to 1.
|
||||
topp (float, optional): 控制随机性参数,数值越大则随机性越大,范围是0~1. Defaults to 0.7.
|
||||
temperature (float, optional): 温度值. Defaults to 0.95.
|
||||
frequency_score (float, optional): 频率分数. Defaults to 0.0.
|
||||
penalty_score (float, optional): 惩罚分数. Defaults to 1.0.
|
||||
presence_score (float, optional): 存在分数. Defaults to 0.0.
|
||||
system (str, optional): 系统设定. Defaults to None.
|
||||
**kwargs: 其他参数
|
||||
req_id (str, optional): 请求ID,用于区分不同的请求. Defaults to None.
|
||||
eos_token_ids (List[int], optional): 指定结束的token id. Defaults to None.
|
||||
timeout (int, optional): 请求超时时间,不设置则使用120s. Defaults to None.
|
||||
message (Union[str, List[str], ChatMessage]): message or ChatMessage object
|
||||
max_dec_len (int, optional): max decoding length. Defaults to 1024.
|
||||
min_dec_len (int, optional): min decoding length. Defaults to 1.
|
||||
topp (float, optional): randomness of the generated tokens. Defaults to 0.7.
|
||||
temperature (float, optional): temperature. Defaults to 0.95.
|
||||
frequency_score (float, optional): frequency score. Defaults to 0.0.
|
||||
penalty_score (float, optional): penalty score. Defaults to 1.0.
|
||||
presence_score (float, optional): presence score. Defaults to 0.0.
|
||||
system (str, optional): system settings. Defaults to None.
|
||||
**kwargs: others
|
||||
|
||||
For more details, please refer to https://github.com/PaddlePaddle/FastDeploy/blob/develop/llm/docs/FastDeploy_usage_tutorial.md#%E8%AF%B7%E6%B1%82%E5%8F%82%E6%95%B0%E4%BB%8B%E7%BB%8D
|
||||
|
||||
Returns:
|
||||
返回一个字典。
|
||||
正常情况下,返回字典的示例{"req_id": "xxx", "results": "好的,我知道了。"}
|
||||
错误情况下,返回错误信息的字典,示例 {"req_id": "xxx", "error_msg": "error message"}
|
||||
return the entire sentence or error message.
|
||||
Normal, return {'tokens_all': xxx, ..., 'error_msg': '', 'error_code': 0}
|
||||
Others, return {'error_msg': xxx, 'error_code': xxx}, error_msg not None, error_code != 0
|
||||
"""
|
||||
stream_response = self.stream_generate(message, max_dec_len,
|
||||
min_dec_len, topp, temperature,
|
||||
@@ -205,7 +197,7 @@ class ChatBotClass(object):
|
||||
system=None,
|
||||
**kwargs):
|
||||
"""
|
||||
准备输入数据。
|
||||
Prepare to input data
|
||||
"""
|
||||
inputs = {
|
||||
"max_dec_len": max_dec_len,
|
||||
@@ -248,7 +240,7 @@ class ChatBotClass(object):
|
||||
|
||||
def _format_response(self, response, req_id):
|
||||
"""
|
||||
对服务返回字段进行格式化
|
||||
Format the service return fields
|
||||
"""
|
||||
response = json.loads(response.as_numpy("OUT")[0])
|
||||
if isinstance(response, (list, tuple)):
|
||||
@@ -273,13 +265,17 @@ class ChatBotClass(object):
|
||||
|
||||
|
||||
class OutputData:
|
||||
"""接收Triton服务返回的数据"""
|
||||
"""
|
||||
Receive data returned by Triton service
|
||||
"""
|
||||
def __init__(self):
|
||||
self._completed_requests = queue.Queue()
|
||||
|
||||
|
||||
def triton_callback(output_data, result, error):
|
||||
"""Triton客户端的回调函数"""
|
||||
"""
|
||||
callback function for Triton server
|
||||
"""
|
||||
if error:
|
||||
output_data._completed_requests.put(error)
|
||||
else:
|
||||
@@ -288,17 +284,17 @@ def triton_callback(output_data, result, error):
|
||||
|
||||
class ChatBot(object):
|
||||
"""
|
||||
对外的接口,用于创建ChatBotForPushMode的示例
|
||||
External interface, create a client object ChatBotForPushMode
|
||||
"""
|
||||
def __new__(cls, hostname, port, timeout=120):
|
||||
"""
|
||||
初始化函数,用于创建一个GRPCInferenceService客户端对象
|
||||
initialize a GRPCInferenceService client
|
||||
Args:
|
||||
hostname (str): 服务器的地址
|
||||
port (int): 服务器的端口号
|
||||
timeout (int): 请求超时时间,单位为秒,默认120秒
|
||||
hostname (str): server hostname
|
||||
port (int): GRPC port
|
||||
timeout (int): timeout(s), default 120 seconds
|
||||
Returns:
|
||||
ChatBotClass: 返回一个BaseChatBot对象
|
||||
ChatBotClass: BaseChatBot object
|
||||
"""
|
||||
if not isinstance(hostname, str) or not hostname:
|
||||
raise ValueError("Invalid hostname")
|
||||
|
Reference in New Issue
Block a user