update code comments

kevincheng2
2024-09-02 13:48:47 +08:00
parent 8500f5dfec
commit 4bc8dc38b0
23 changed files with 750 additions and 577 deletions


@@ -60,28 +60,25 @@ class ChatBotClass(object):
Streaming interface
Args:
message (Union[str, List[str], ChatMessage]): message content or a ChatMessage object
max_dec_len (int, optional): maximum decoding length. Defaults to 1024.
min_dec_len (int, optional): minimum decoding length. Defaults to 1.
topp (float, optional): controls randomness; the larger the value, the more random, range 0~1. Defaults to 0.7.
temperature (float, optional): temperature value. Defaults to 0.95.
frequency_score (float, optional): frequency score. Defaults to 0.0.
penalty_score (float, optional): penalty score. Defaults to 1.0.
presence_score (float, optional): presence score. Defaults to 0.0.
system (str, optional): system setting. Defaults to None.
**kwargs: other parameters
req_id (str, optional): request ID used to distinguish different requests. Defaults to None.
eos_token_ids (List[int], optional): token ids that end decoding. Defaults to None.
benchmark (bool, optional): benchmark mode; if enabled, the full response is returned. Defaults to False.
timeout (int, optional): request timeout in seconds; 120s is used if not set. Defaults to None.
message (Union[str, List[str], ChatMessage]): message or ChatMessage object
max_dec_len (int, optional): max decoding length. Defaults to 1024.
min_dec_len (int, optional): min decoding length. Defaults to 1.
topp (float, optional): randomness of the generated tokens. Defaults to 0.7.
temperature (float, optional): temperature. Defaults to 0.95.
frequency_score (float, optional): frequency score. Defaults to 0.0.
penalty_score (float, optional): penalty score. Defaults to 1.0.
presence_score (float, optional): presence score. Defaults to 0.0.
system (str, optional): system settings. Defaults to None.
**kwargs: others
For more details, please refer to https://github.com/PaddlePaddle/FastDeploy/blob/develop/llm/docs/FastDeploy_usage_tutorial.md#%E8%AF%B7%E6%B1%82%E5%8F%82%E6%95%B0%E4%BB%8B%E7%BB%8D
Returns:
Returns a generator; each yield produces a dict.
On success, the yielded dict looks like {"req_id": "xxx", "token": "好的", "is_end": 0}, where token is the generated text and is_end indicates whether it is the last piece (0 for no, 1 for yes).
On error, the generator yields an error dict, e.g. {"req_id": "xxx", "error_msg": "error message"}.
Returns a generator object which yields a dict.
On success: {'token': xxx, 'is_end': xxx, 'send_idx': xxx, ..., 'error_msg': '', 'error_code': 0}
On error: {'error_msg': xxx, 'error_code': xxx}, with error_msg not None and error_code != 0
"""
try:
# Prepare the inputs
model_name = "model"
inputs = [grpcclient.InferInput("IN", [1], triton_utils.np_to_triton_dtype(np.object_))]
outputs = [grpcclient.InferRequestedOutput("OUT")]
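For context, a minimal usage sketch of the streaming interface documented above; the import path, host, port, and prompt below are placeholders, not values taken from this commit:

from fastdeploy_client.chatbot import ChatBot  # import path is an assumption

chatbot = ChatBot(hostname="127.0.0.1", port=8812)
for piece in chatbot.stream_generate("hello", max_dec_len=1024, topp=0.7):
    if piece.get("error_msg"):            # error dict: error_msg set, error_code != 0
        print("request failed:", piece["error_msg"])
        break
    print(piece["token"], end="", flush=True)
    if piece.get("is_end") == 1:          # last piece of the reply
        break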
@@ -96,14 +93,11 @@ class ChatBotClass(object):
timeout = kwargs.get("timeout", self.timeout)
with grpcclient.InferenceServerClient(url=self.url, verbose=False) as triton_client:
# Establish the streaming connection
triton_client.start_stream(callback=partial(triton_callback, output_data))
# Send the request
triton_client.async_stream_infer(model_name=model_name,
inputs=inputs,
request_id=req_id,
outputs=outputs)
# Process the results
answer_str = ""
enable_benchmark = is_enable_benchmark(**kwargs)
while True:
@@ -129,7 +123,6 @@ class ChatBotClass(object):
yield response
if response.get("is_end") == 1 or response.get("error_msg") is not None:
break
# Close the stream manually
triton_client.stop_stream(cancel_requests=True)
triton_client.close()
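The body of the while loop is elided by this diff. A rough sketch of how the results pushed by the callback might be drained, assuming the OutputData queue defined later in the file and the _format_response helper; this is a guess at the flow, not the committed code:

while True:
    # Block until triton_callback puts a result (or an error) on the queue.
    item = output_data._completed_requests.get(timeout=timeout)
    if isinstance(item, Exception):
        yield {"req_id": req_id, "error_msg": str(item), "error_code": 1}
        break
    response = self._format_response(item, req_id)
    # ... the yield / is_end handling shown in the hunk above follows here.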
@@ -150,27 +143,26 @@ class ChatBotClass(object):
system=None,
**kwargs):
"""
Return the whole sentence at once, implemented on top of the streaming interface.
Return the entire sentence using the streaming interface.
Args:
message (Union[str, List[str], ChatMessage]): message content or a ChatMessage object
max_dec_len (int, optional): maximum decoding length. Defaults to 1024.
min_dec_len (int, optional): minimum decoding length. Defaults to 1.
topp (float, optional): controls randomness; the larger the value, the more random, range 0~1. Defaults to 0.7.
temperature (float, optional): temperature value. Defaults to 0.95.
frequency_score (float, optional): frequency score. Defaults to 0.0.
penalty_score (float, optional): penalty score. Defaults to 1.0.
presence_score (float, optional): presence score. Defaults to 0.0.
system (str, optional): system setting. Defaults to None.
**kwargs: other parameters
req_id (str, optional): request ID used to distinguish different requests. Defaults to None.
eos_token_ids (List[int], optional): token ids that end decoding. Defaults to None.
timeout (int, optional): request timeout in seconds; 120s is used if not set. Defaults to None.
message (Union[str, List[str], ChatMessage]): message or ChatMessage object
max_dec_len (int, optional): max decoding length. Defaults to 1024.
min_dec_len (int, optional): min decoding length. Defaults to 1.
topp (float, optional): randomness of the generated tokens. Defaults to 0.7.
temperature (float, optional): temperature. Defaults to 0.95.
frequency_score (float, optional): frequency score. Defaults to 0.0.
penalty_score (float, optional): penalty score. Defaults to 1.0.
presence_score (float, optional): presence score. Defaults to 0.0.
system (str, optional): system settings. Defaults to None.
**kwargs: others
For more details, please refer to https://github.com/PaddlePaddle/FastDeploy/blob/develop/llm/docs/FastDeploy_usage_tutorial.md#%E8%AF%B7%E6%B1%82%E5%8F%82%E6%95%B0%E4%BB%8B%E7%BB%8D
Returns:
Returns a dict.
On success, e.g. {"req_id": "xxx", "results": "好的,我知道了。"}
On error, an error dict, e.g. {"req_id": "xxx", "error_msg": "error message"}
Returns the entire sentence or an error message.
On success: {'tokens_all': xxx, ..., 'error_msg': '', 'error_code': 0}
On error: {'error_msg': xxx, 'error_code': xxx}, with error_msg not None and error_code != 0
"""
stream_response = self.stream_generate(message, max_dec_len,
min_dec_len, topp, temperature,
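The call above is truncated by the diff. Reusing the chatbot object from the earlier sketch, the whole-sentence interface might be used like this; the prompt is a placeholder:

result = chatbot.generate("hello", max_dec_len=256)
if result.get("error_code", 0) != 0:
    print("request failed:", result.get("error_msg"))
else:
    print(result["tokens_all"])           # the full reply accumulated from the stream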
@@ -205,7 +197,7 @@ class ChatBotClass(object):
system=None,
**kwargs):
"""
Prepare the input data.
Prepare the input data
"""
inputs = {
"max_dec_len": max_dec_len,
@@ -248,7 +240,7 @@ class ChatBotClass(object):
def _format_response(self, response, req_id):
"""
Format the fields returned by the service
Format the service return fields
"""
response = json.loads(response.as_numpy("OUT")[0])
if isinstance(response, (list, tuple)):
@@ -273,13 +265,17 @@ class ChatBotClass(object):
class OutputData:
"""接收Triton服务返回的数据"""
"""
Receive data returned by Triton service
"""
def __init__(self):
self._completed_requests = queue.Queue()
def triton_callback(output_data, result, error):
"""Triton客户端的回调函数"""
"""
callback function for Triton server
"""
if error:
output_data._completed_requests.put(error)
else:
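The else branch is cut off by the diff; presumably it puts the successful result on the same queue. For context, a minimal sketch of how the callback is bound to an OutputData instance and registered on a stream, mirroring the start_stream call in the hunks above; the URL is a placeholder:

from functools import partial
import tritonclient.grpc as grpcclient

output_data = OutputData()
with grpcclient.InferenceServerClient(url="127.0.0.1:8812", verbose=False) as client:
    # Every result (or error) for the stream lands in output_data._completed_requests.
    client.start_stream(callback=partial(triton_callback, output_data))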
@@ -288,17 +284,17 @@ def triton_callback(output_data, result, error):
class ChatBot(object):
"""
External interface for creating a ChatBotForPushMode instance
External interface that creates a ChatBotForPushMode client object
"""
def __new__(cls, hostname, port, timeout=120):
"""
Initialization function that creates a GRPCInferenceService client object
initialize a GRPCInferenceService client
Args:
hostname (str): server address
port (int): server port
timeout (int): request timeout in seconds. Defaults to 120 seconds.
hostname (str): server hostname
port (int): GRPC port
timeout (int): timeout in seconds, default 120 seconds
Returns:
ChatBotClass: returns a BaseChatBot object
ChatBotClass: BaseChatBot object
"""
if not isinstance(hostname, str) or not hostname:
raise ValueError("Invalid hostname")
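A quick sketch of the validation above; an empty hostname is rejected before any client is created (the port value is a placeholder):

try:
    ChatBot(hostname="", port=8812)
except ValueError as err:
    print(err)   # Invalid hostname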