update code comments
(mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-06 09:07:10 +08:00)
@@ -60,28 +60,25 @@ class ChatBotClass(object):
Streaming interface

Args:
-message (Union[str, List[str], ChatMessage]): 消息内容或ChatMessage对象
+message (Union[str, List[str], ChatMessage]): message or ChatMessage object
-max_dec_len (int, optional): 最大解码长度. Defaults to 1024.
+max_dec_len (int, optional): max decoding length. Defaults to 1024.
-min_dec_len (int, optional): 最小解码长度. Defaults to 1.
+min_dec_len (int, optional): min decoding length. Defaults to 1.
-topp (float, optional): 控制随机性参数,数值越大则随机性越大,范围是0~1. Defaults to 0.7.
+topp (float, optional): randomness of the generated tokens. Defaults to 0.7.
-temperature (float, optional): 温度值. Defaults to 0.95.
+temperature (float, optional): temperature. Defaults to 0.95.
-frequency_score (float, optional): 频率分数. Defaults to 0.0.
+frequency_score (float, optional): frequency score. Defaults to 0.0.
-penalty_score (float, optional): 惩罚分数. Defaults to 1.0.
+penalty_score (float, optional): penalty score. Defaults to 1.0.
-presence_score (float, optional): 存在分数. Defaults to 0.0.
+presence_score (float, optional): presence score. Defaults to 0.0.
-system (str, optional): 系统设定. Defaults to None.
+system (str, optional): system settings. Defaults to None.
-**kwargs: 其他参数
+**kwargs: others
-req_id (str, optional): 请求ID,用于区分不同的请求. Defaults to None.
-eos_token_ids (List[int], optional): 指定结束的token id. Defaults to None.
+For more details, please refer to https://github.com/PaddlePaddle/FastDeploy/blob/develop/llm/docs/FastDeploy_usage_tutorial.md#%E8%AF%B7%E6%B1%82%E5%8F%82%E6%95%B0%E4%BB%8B%E7%BB%8D
-benchmark (bool, optional): 设置benchmark模式,如果是则返回完整的response. Defaults to False.
-timeout (int, optional): 请求超时时间,不设置则使用120s. Defaults to None.

Returns:
-返回一个生成器,每次yield返回一个字典。
+return a generator object, which yields a dict.
-正常情况下,生成器返回字典的示例{"req_id": "xxx", "token": "好的", "is_end": 0},其中token为生成的字符,is_end表明是否为最后一个字符(0表示否,1表示是)
+Normal, return {'token': xxx, 'is_end': xxx, 'send_idx': xxx, ..., 'error_msg': '', 'error_code': 0}
-错误情况下,生成器返回错误信息的字典,示例 {"req_id": "xxx", "error_msg": "error message"}
+Others, return {'error_msg': xxx, 'error_code': xxx}, error_msg not None, error_code != 0
"""
try:
-# 准备输入
model_name = "model"
inputs = [grpcclient.InferInput("IN", [1], triton_utils.np_to_triton_dtype(np.object_))]
outputs = [grpcclient.InferRequestedOutput("OUT")]
@@ -96,14 +93,11 @@ class ChatBotClass(object):
timeout = kwargs.get("timeout", self.timeout)

with grpcclient.InferenceServerClient(url=self.url, verbose=False) as triton_client:
-# 建立连接
triton_client.start_stream(callback=partial(triton_callback, output_data))
-# 发送请求
triton_client.async_stream_infer(model_name=model_name,
inputs=inputs,
request_id=req_id,
outputs=outputs)
-# 处理结果
answer_str = ""
enable_benchmark = is_enable_benchmark(**kwargs)
while True:
@@ -129,7 +123,6 @@ class ChatBotClass(object):
yield response
if response.get("is_end") == 1 or response.get("error_msg") is not None:
break
-# 手动关闭
triton_client.stop_stream(cancel_requests=True)
triton_client.close()

@@ -150,27 +143,26 @@ class ChatBotClass(object):
system=None,
**kwargs):
"""
-整句返回,直接使用流式返回的接口。
+Return the entire sentence using the streaming interface.

Args:
-message (Union[str, List[str], ChatMessage]): 消息内容或ChatMessage对象
+message (Union[str, List[str], ChatMessage]): message or ChatMessage object
-max_dec_len (int, optional): 最大解码长度. Defaults to 1024.
+max_dec_len (int, optional): max decoding length. Defaults to 1024.
-min_dec_len (int, optional): 最小解码长度. Defaults to 1.
+min_dec_len (int, optional): min decoding length. Defaults to 1.
-topp (float, optional): 控制随机性参数,数值越大则随机性越大,范围是0~1. Defaults to 0.7.
+topp (float, optional): randomness of the generated tokens. Defaults to 0.7.
-temperature (float, optional): 温度值. Defaults to 0.95.
+temperature (float, optional): temperature. Defaults to 0.95.
-frequency_score (float, optional): 频率分数. Defaults to 0.0.
+frequency_score (float, optional): frequency score. Defaults to 0.0.
-penalty_score (float, optional): 惩罚分数. Defaults to 1.0.
+penalty_score (float, optional): penalty score. Defaults to 1.0.
-presence_score (float, optional): 存在分数. Defaults to 0.0.
+presence_score (float, optional): presence score. Defaults to 0.0.
-system (str, optional): 系统设定. Defaults to None.
+system (str, optional): system settings. Defaults to None.
-**kwargs: 其他参数
+**kwargs: others
-req_id (str, optional): 请求ID,用于区分不同的请求. Defaults to None.
-eos_token_ids (List[int], optional): 指定结束的token id. Defaults to None.
+For more details, please refer to https://github.com/PaddlePaddle/FastDeploy/blob/develop/llm/docs/FastDeploy_usage_tutorial.md#%E8%AF%B7%E6%B1%82%E5%8F%82%E6%95%B0%E4%BB%8B%E7%BB%8D
-timeout (int, optional): 请求超时时间,不设置则使用120s. Defaults to None.

Returns:
-返回一个字典。
+return the entire sentence or error message.
-正常情况下,返回字典的示例{"req_id": "xxx", "results": "好的,我知道了。"}
+Normal, return {'tokens_all': xxx, ..., 'error_msg': '', 'error_code': 0}
-错误情况下,返回错误信息的字典,示例 {"req_id": "xxx", "error_msg": "error message"}
+Others, return {'error_msg': xxx, 'error_code': xxx}, error_msg not None, error_code != 0
"""
stream_response = self.stream_generate(message, max_dec_len,
min_dec_len, topp, temperature,
@@ -205,7 +197,7 @@ class ChatBotClass(object):
system=None,
**kwargs):
"""
-准备输入数据。
+Prepare to input data
"""
inputs = {
"max_dec_len": max_dec_len,
@@ -248,7 +240,7 @@ class ChatBotClass(object):

def _format_response(self, response, req_id):
"""
-对服务返回字段进行格式化
+Format the service return fields
"""
response = json.loads(response.as_numpy("OUT")[0])
if isinstance(response, (list, tuple)):
@@ -273,13 +265,17 @@ class ChatBotClass(object):


class OutputData:
-"""接收Triton服务返回的数据"""
+"""
+Receive data returned by Triton service
+"""
def __init__(self):
self._completed_requests = queue.Queue()


def triton_callback(output_data, result, error):
-"""Triton客户端的回调函数"""
+"""
+callback function for Triton server
+"""
if error:
output_data._completed_requests.put(error)
else:
@@ -288,17 +284,17 @@ def triton_callback(output_data, result, error):

class ChatBot(object):
"""
-对外的接口,用于创建ChatBotForPushMode的示例
+External interface, create a client object ChatBotForPushMode
"""
def __new__(cls, hostname, port, timeout=120):
"""
-初始化函数,用于创建一个GRPCInferenceService客户端对象
+initialize a GRPCInferenceService client
Args:
-hostname (str): 服务器的地址
+hostname (str): server hostname
-port (int): 服务器的端口号
+port (int): GRPC port
-timeout (int): 请求超时时间,单位为秒,默认120秒
+timeout (int): timeout(s), default 120 seconds
Returns:
-ChatBotClass: 返回一个BaseChatBot对象
+ChatBotClass: BaseChatBot object
"""
if not isinstance(hostname, str) or not hostname:
raise ValueError("Invalid hostname")
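The docstrings changed above describe the full client call surface. A minimal usage sketch of this gRPC client follows (illustrative only, not part of this diff; the `stream_generate`/`generate` method names are inferred from the surrounding context and the CLI below, and the hostname/port values are placeholders):

```python
# Sketch only: assumes a FastDeploy LLM service is reachable at the given gRPC port.
from fastdeploy_client.chatbot import ChatBot

chatbot = ChatBot(hostname="127.0.0.1", port=8811, timeout=120)  # placeholder address

# Streaming: the generator yields dicts such as
# {"token": "...", "is_end": 0, "send_idx": ..., "error_msg": "", "error_code": 0}.
for chunk in chatbot.stream_generate("hello", max_dec_len=64, topp=0.7, temperature=0.95):
    if chunk.get("error_msg"):
        break
    print(chunk.get("token", ""), end="")

# Whole-sentence variant: a single dict, e.g. {"tokens_all": "...", "error_code": 0}.
result = chatbot.generate("hello", max_dec_len=64)
print(result)
```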
@@ -21,7 +21,10 @@ from fastdeploy_client.chatbot import ChatBot
|
|||||||
|
|
||||||
def _get_service_configuration():
|
def _get_service_configuration():
|
||||||
"""
|
"""
|
||||||
从环境变量获取服务配置信息
|
get service url from environment
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: (hostname, port)
|
||||||
"""
|
"""
|
||||||
url = os.getenv("FASTDEPLOY_MODEL_URL")
|
url = os.getenv("FASTDEPLOY_MODEL_URL")
|
||||||
|
|
||||||
@@ -38,7 +41,7 @@ def _get_service_configuration():
|
|||||||
|
|
||||||
def stream_generate(prompt):
|
def stream_generate(prompt):
|
||||||
"""
|
"""
|
||||||
命令工具:流式返回
|
Streaming interface
|
||||||
"""
|
"""
|
||||||
hostname, port = _get_service_configuration()
|
hostname, port = _get_service_configuration()
|
||||||
chatbot = ChatBot(hostname=hostname, port=port)
|
chatbot = ChatBot(hostname=hostname, port=port)
|
||||||
@@ -49,7 +52,7 @@ def stream_generate(prompt):
|
|||||||
|
|
||||||
def generate(prompt):
|
def generate(prompt):
|
||||||
"""
|
"""
|
||||||
命令工具:整句返回
|
entire sentence interface
|
||||||
"""
|
"""
|
||||||
hostname, port = _get_service_configuration()
|
hostname, port = _get_service_configuration()
|
||||||
chatbot = ChatBot(hostname=hostname, port=port)
|
chatbot = ChatBot(hostname=hostname, port=port)
|
||||||
@@ -58,9 +61,6 @@ def generate(prompt):
|
|||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
"""
|
|
||||||
命令工具主入口
|
|
||||||
"""
|
|
||||||
if len(sys.argv) < 2 or sys.argv[1] not in ["generate", "stream_generate"]:
|
if len(sys.argv) < 2 or sys.argv[1] not in ["generate", "stream_generate"]:
|
||||||
logging.error("Usage 1: fdclient generate \"Hello, How are you?\"")
|
logging.error("Usage 1: fdclient generate \"Hello, How are you?\"")
|
||||||
return
|
return
|
||||||
|
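For reference (not part of this diff): `_get_service_configuration()` reads `FASTDEPLOY_MODEL_URL` and returns a `(hostname, port)` tuple, so the variable is assumed here to hold a `hostname:port` string; the exact accepted format is not shown in the hunk.

```python
# Sketch only: the "hostname:port" format and the values are assumptions.
import os

os.environ["FASTDEPLOY_MODEL_URL"] = "127.0.0.1:8811"  # placeholder

url = os.getenv("FASTDEPLOY_MODEL_URL")
hostname, port = url.split(":")
print(hostname, int(port))
```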
@@ -14,8 +14,7 @@

class ChatMessage(object):
"""
-多轮对话数据结构,当使用这个与ChatBot对话时
-会将对话记录存储在此结构体内,支持多轮
+multi-turn chat message with ChatBot
"""
def __init__(self, prompt=None):
if prompt is not None:
@@ -25,7 +24,7 @@ class ChatMessage(object):

def add_user_message(self, content):
"""
-添加一个用户消息
+add user message
"""
if len(self.message) > 0 and self.message[-1]["role"] != "assistant":
raise Exception("Cannot add user message, because the role of the "
@@ -34,7 +33,7 @@ class ChatMessage(object):

def add_assistant_message(self, content):
"""
-添加一个assistant消息
+add assistant message
"""
if len(self.message) > 0 and self.message[-1]["role"] != "user":
raise Exception("Cannot add user message, because the role of the "
@@ -43,7 +42,7 @@ class ChatMessage(object):

def next_prompt(self, content):
"""
-添加一个新的对话,保留用于兼容。
+add user message and return a new prompt
"""
self.add_user_message(content)

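As the checks above show, `ChatMessage` enforces alternating user/assistant roles. A short usage sketch (illustrative only, not part of this diff; the module path is assumed):

```python
# Sketch only: assumes ChatMessage stores turns as
# [{"role": "user"/"assistant", "content": ...}, ...] as implied by the checks above.
from fastdeploy_client.message import ChatMessage  # import path assumed

msg = ChatMessage("What is PaddlePaddle?")                      # first user turn
msg.add_assistant_message("PaddlePaddle is a deep learning framework.")
msg.add_user_message("Does it support LLM serving?")            # roles must alternate

# The object can then be passed as the `message` argument of the ChatBot client.
```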
@@ -13,5 +13,7 @@
# limitations under the License.

def is_enable_benchmark(**kwargs):
-"""是否是benchmark模式"""
+"""
+Check if enable benchmark
+"""
return "benchmark" in kwargs and kwargs["benchmark"] == 1
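The helper above treats only `benchmark == 1` as enabled. A quick, self-contained check of that behavior (illustrative only, not part of this diff):

```python
# Copy of the helper above, plus assertions showing how the flag is interpreted.
def is_enable_benchmark(**kwargs):
    """Check if benchmark mode is enabled."""
    return "benchmark" in kwargs and kwargs["benchmark"] == 1

assert is_enable_benchmark(benchmark=1) is True
assert is_enable_benchmark(benchmark=0) is False
assert is_enable_benchmark() is False
```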
@@ -1,15 +1,16 @@
FROM registry.baidubce.com/paddlepaddle/fastdeploy:llm-base-gcc12.3-cuda11.8-cudnn8-nccl2.15.5

WORKDIR /opt/output/
-COPY ./server/ /opt/output/Serving
+COPY ./server/ /opt/output/Serving/
COPY ./client/ /opt/output/client/

+ENV LD_LIBRARY_PATH "/usr/local/cuda-11.8/compat/:$LD_LIBRARY_PATH"

RUN python3 -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu118/ \
&& python3 -m pip install paddlenlp==3.0.0b0 \
&& python3 -m pip install --no-cache-dir sentencepiece pycryptodome tritonclient[all]==2.41.1 \
&& apt-get clean && rm -rf /var/lib/apt/lists/*

-ENV LD_LIBRARY_PATH "/usr/local/cuda-11.8/compat/:$LD_LIBRARY_PATH"
RUN git clone https://gitee.com/paddlepaddle/PaddleNLP.git && cd PaddleNLP/csrc \
&& python3 setup_cuda.py build && python3 setup_cuda.py install --user \
&& cp -r /opt/output/PaddleNLP/paddlenlp /usr/local/lib/python3.10/dist-packages/ \
@@ -1,15 +1,16 @@
FROM registry.baidubce.com/paddlepaddle/fastdeploy:llm-base-gcc12.3-cuda12.3-cudnn9-nccl2.15.5

WORKDIR /opt/output/
-COPY ./server/ /opt/output/Serving
+COPY ./server/ /opt/output/Serving/
COPY ./client/ /opt/output/client/

+ENV LD_LIBRARY_PATH "/usr/local/cuda-12.3/compat/:$LD_LIBRARY_PATH"

RUN python3 -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu123/ \
&& python3 -m pip install paddlenlp==3.0.0b0 \
&& python3 -m pip install --no-cache-dir sentencepiece pycryptodome tritonclient[all]==2.41.1 \
&& apt-get clean && rm -rf /var/lib/apt/lists/*

-ENV LD_LIBRARY_PATH "/usr/local/cuda-12.3/compat/:$LD_LIBRARY_PATH"
RUN git clone https://gitee.com/paddlepaddle/PaddleNLP.git && cd PaddleNLP/csrc \
&& python3 setup_cuda.py build && python3 setup_cuda.py install --user \
&& cp -r /opt/output/PaddleNLP/paddlenlp /usr/local/lib/python3.10/dist-packages/ \
@@ -13,8 +13,9 @@
- [服务测试](#服务测试)
- [Python 客户端](#Python-客户端)
- [HTTP调用](#HTTP调用)
-- [请求参数介绍](#请求参数介绍)
- [返回示例](#返回示例)
+- [模型配置参数介绍](#模型配置参数介绍)
+- [请求参数介绍](#请求参数介绍)

## 部署环境准备

@@ -33,12 +34,12 @@ cd /home/workspace/models_dir

# 模型内目录结构需要整理成特定格式,如下是单卡部署的模型目录结构
# /opt/output/Serving/models
-# ├── config.json # 模型配置文件(必选)
+# ├── config.json # 模型配置文件
-# ├── xxxx.model # 词表模型文件(必选)
+# ├── xxxx.model # 词表模型文件
-# ├── special_tokens_map.json # 词表配置文件(必选)
+# ├── special_tokens_map.json # 词表配置文件
-# ├── tokenizer_config.json # 词表配置文件(必选)
+# ├── tokenizer_config.json # 词表配置文件
# ├── rank_mapping.csv # 多卡模型会有此文件,如为单卡模型,则无此文件(可选,仅在多卡部署模式下需要)
-# └── rank_0 # 保存模型结构和权重文件的目录(必选)
+# └── rank_0 # 保存模型结构和权重文件的目录
# ├── model.pdiparams
# └── model.pdmodel
```
@@ -114,6 +115,8 @@ export MAX_CACHED_TASK_NUM="128" # 服务缓存队列最大长度,队列达
export PUSH_MODE_HTTP_WORKERS="1" # HTTP服务进程数,在 PUSH_MODE_HTTP_PORT 配置的情况下有效,最高设置到8即可,默认为1
```

+更多请求参数请参考[模型配置参数介绍](#模型配置参数介绍)
+
### 启动FastDeploy

```
@@ -165,7 +168,7 @@ import uuid
import json
import requests

-url = f"http://0.0.0.0:{PUSH_MODE_HTTP_PORT}/v1/chat/completions"
+url = f"http://127.0.0.1:{PUSH_MODE_HTTP_PORT}/v1/chat/completions"
req_id = str(uuid.uuid1())
data = {
"text": "Hello, how are you?",
@@ -179,7 +182,47 @@ for line in res.iter_lines():
print(json.loads(line))
```

-### 请求参数介绍
+更多请求参数请参考[请求参数介绍](#请求参数介绍)

+### 返回示例
+
+```
+如果stream为True,流式返回
+如果正常,返回{'token': xxx, 'is_end': xxx, 'send_idx': xxx, ..., 'error_msg': '', 'error_code': 0}
+如果异常,返回{'error_msg': xxx, 'error_code': xxx},error_msg字段不为空,error_code字段不为0
+
+如果stream为False,非流式返回
+如果正常,返回{'tokens_all': xxx, ..., 'error_msg': '', 'error_code': 0}
+如果异常,返回{'error_msg': xxx, 'error_code': xxx},error_msg字段不为空,error_code字段不为0
+```
+
+## 模型配置参数介绍
+
+| 字段名 | 字段类型 | 说明 | 是否必填 | 默认值 | 备注 |
+| :---: | :-----: | :---: | :---: | :-----: | :----: |
+| MP_NUM | int | 模型并行度 | 否 | 8 | CUDA_VISIBLE_DEVICES 需配置对应卡数 |
+| CUDA_VISIBLE_DEVICES | str | 使用 GPU 编号 | 否 | 0,1,2,3,4,5,6,7 | |
+| HTTP_PORT | int | 探活服务的http端口 | 是 | 无 | 当前仅用于健康检查、探活 |
+| GRPC_PORT | int | 模型推服务的grpc端口 | 是 | 无 | |
+| METRICS_PORT | int | 模型服务中监督指标的端口 | 是 | 无 | |
+| INFER_QUEUE_PORT | int | 模型服务内部使用的端口 | 否 | 56666 | |
+| PUSH_MODE_HTTP_PORT | int | 服务请求HTTP端口号 | 否 | -1 | 如不配置,服务只支持GRPC协议 |
+| DISABLE_STREAMING | int | 是否使用流式返回 | 否 | 0 | |
+| MAX_SEQ_LEN | int | 最大输入序列长度 | 否 | 8192 | 服务会拒绝input token数量超过MAX_SEQ_LEN的请求,并返回错误提示 |
+| MAX_DEC_LEN | int | 最大decoer序列长度 | 否 | 1024 | 服务会拒绝请求中max_dec_len/min_dec_len超过此参数的请求,并返回错误提示 |
+| BATCH_SIZE | int | 最大Batch Size | 否 | 50 | 模型可同时并发处理的最大输入数量,不能高于128 |
+| BLOCK_BS | int | 缓存Block支持的最大Query Batch Size | 否 | 50 | 如果出现out of memeory 错误,尝试减少该数值 |
+| BLOCK_RATIO | float | | 否 | 0.75 | 建议配置 输入平均Token数/(输入+输出平均Token数) |
+| MAX_CACHED_TASK_NUM | int | 服务缓存队列最大长度 | 否 | 128 | 队列达到上限后,会拒绝新的请求 |
+| PUSH_MODE_HTTP_WORKERS | int | HTTP服务进程数 | 否 | 1 | 在 PUSH_MODE_HTTP_PORT 配置的情况下有效,高并发下提高该数值,建议最高配置为8 |
+| USE_WARMUP | int | 是否进行 warmup | 否 | 0 | |
+| USE_HF_TOKENIZER | int | 是否进行使用huggingface的词表 | 否 | 0 | |
+| USE_CACHE_KV_INT8 | int | 是否将INT8配置为KV Cache的类型 | 否 | 0 | c8量化模型需要配置为1 |
+| MODEL_DIR | str | 模型文件路径 | 否 | /models/ | |
+| FD_MODEL_CONFIG_PATH | str | 模型config文件路径 | 否 | ${model_dir}/config.json | |
+| DISTRIBUTED_CONFIG | str | 模型分布式配置文件路径 | 否 | ${model_dir}/rank_mapping.csv | |
+
+## 请求参数介绍
+
| 字段名 | 字段类型 | 说明 | 是否必填 | 默认值 | 备注 |
| :---: | :-----: | :---: | :---: | :-----: | :----: |
@@ -195,19 +238,8 @@ for line in res.iter_lines():
| stream | bool | 是否流式返回 | 否 | False | |
| return_all_tokens | bool | 是否一次性返回所有结果 | 否 | False | 与stream参数差异见表后备注 |
| timeout | int | 请求等待的超时时间,单位是秒 | 否 | 300 | |
+| return_usage | bool | 是否返回输入、输出 token 数量 | 否 | False | |

* 在正确配置PUSH_MODE_HTTP_PORT字段下,服务支持 GRPC 和 HTTP 两种请求服务
* stream 参数仅对 HTTP 请求生效
* return_all_tokens 参数对 GRPC 和 HTTP 请求均有效

-### 返回示例
-
-```
-如果stream为True,流式返回
-如果正常,返回{'token': xxx, 'is_end': xxx, 'send_idx': xxx, ..., 'error_msg': '', 'error_code': 0}
-如果异常,返回{'error_msg': xxx, 'error_code': xxx},error_msg字段不为空,error_code字段不为0
-
-如果stream为False,非流式返回
-如果正常,返回{'tokens_all': xxx, 'tokens_all_num': xxx, ..., 'error_msg': '', 'error_code': 0}
-如果异常,返回{'error_msg': xxx, 'error_code': xxx},error_msg字段不为空,error_code字段不为0
-```
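To complement the streaming example already in the tutorial above, a non-streaming request sketch using fields from the request-parameter table (`stream`, `return_all_tokens`, `return_usage`) follows; it is illustrative only, not part of this diff, and the port value is a placeholder:

```python
# Sketch only: PUSH_MODE_HTTP_PORT must match the value the service was started with.
import json
import uuid
import requests

PUSH_MODE_HTTP_PORT = 9965  # placeholder
url = f"http://127.0.0.1:{PUSH_MODE_HTTP_PORT}/v1/chat/completions"

data = {
    "text": "Hello, how are you?",
    "req_id": str(uuid.uuid1()),
    "max_dec_len": 64,
    "stream": False,            # whole response in one payload
    "return_all_tokens": True,  # return all tokens at once
    "return_usage": True,       # also return input/output token counts
}
res = requests.post(url, json=data)
# Expected shape per the table above: {'tokens_all': ..., 'error_msg': '', 'error_code': 0}
print(json.loads(res.text))
```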
@@ -20,3 +20,4 @@ pynvml

# paddlenlp
tiktoken
+transformers
@@ -14,8 +14,8 @@ export FLAGS_gemm_use_half_precision_compute_type=0
export NVIDIA_TF32_OVERRIDE=0

# Model hyperparameters
-export MP_NUM=${MP_NUM:-"1"} # GPU num
+export MP_NUM=${MP_NUM:-"1"} # Number of GPUs
-export CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-"0"} # GPU
+export CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-"0"} # GPU ids
export MAX_SEQ_LEN=${MAX_SEQ_LEN:-"8192"}
export MAX_DEC_LEN=${MAX_DEC_LEN:-"2048"}
export BATCH_SIZE=${BATCH_SIZE:-"20"}
@@ -44,7 +44,6 @@ mkdir -p log
rm -rf console.log log/*
rm -rf /dev/shm/*

-# 启动服务
echo "start serving ..."

tritonserver --exit-timeout-secs 100 --cuda-memory-pool-byte-size 0:0 --cuda-memory-pool-byte-size 1:0 \
@@ -55,4 +54,5 @@ tritonserver --exit-timeout-secs 100 --cuda-memory-pool-byte-size 0:0 --cuda-mem
--grpc-port=${GRPC_PORT} \
--metrics-port=${METRICS_PORT} \
--log-file log/server.log --log-info true > log/console.log 2>&1 &
-echo "模型服务的启动日志,请查看" ${PWD}"/log/server.log 和 "${PWD}"/log/workerlog.0 "
+echo "The logs for the model service, please check" ${PWD}"/log/server.log and "${PWD}"/log/workerlog.0"
@@ -3,7 +3,7 @@
pids=($(ps aux | grep -E 'tritonserver' | grep -v grep | awk '{print $2}'))

if [ ${#pids[@]} -eq 0 ]; then
-echo "未找到 tritonserver 相关进程"
+echo "Can not find tritonserver."
timeout=1
else
timeout=300
@@ -11,7 +11,7 @@ fi

# kill processor
for pid in "${pids[@]}"; do
-echo "正在中断进程 $pid"
+echo "killing $pid"
kill -2 "$pid"
done

@@ -29,9 +29,8 @@ while : ; do
elapsed_time=$((current_time - start_time))

if [ $elapsed_time -ge $timeout ]; then
-echo "tritonserver进程超时未退出"
-echo "强制杀死所有有关进程"
-pids=$(ps auxww | grep -E "tritonserver|triton_python_backend_stub|new_infer.py|infer|multiprocessing.resource_tracker|paddle.distributed.launch|task_queue_manager|app.py|memory_log.py|spawn_main" | grep -v grep | grep -v start_both | awk '{print $2}');
+echo "forcibly kill all process ..."
+pids=$(ps auxww | grep -E "tritonserver|triton_python_backend_stub|infer|multiprocessing.resource_tracker|paddle.distributed.launch|task_queue_manager|app.py|spawn_main" | grep -v grep | grep -v start_both | awk '{print $2}');
echo $pids;
for pid in ${pids[@]}; do
kill -9 ${pid}
@@ -39,14 +38,14 @@ while : ; do
break
fi

-pids=$(ps auxww | grep -E "tritonserver|triton_python_backend_stub|new_infer.py|multiprocessing.resource_tracker|paddle.distributed.launch|app.py|memory_log.py|spawn_main" | grep -v grep | awk '{print $2}');
+pids=$(ps auxww | grep -E "tritonserver|triton_python_backend_stub|multiprocessing.resource_tracker|paddle.distributed.launch|app.py|spawn_main" | grep -v grep | awk '{print $2}');
array=($(echo "$pids" | tr ' ' '\n'))

if [ ${#array[*]} -ne 0 ]; then
-echo "进程还没有清理干净, 等待清理完毕"
+echo "cleaning process, please wait ..."
sleep 1
else
-echo "进程已经清理干净"
+echo "clean finished."
break
fi
done
@@ -65,5 +64,5 @@ for in_pid in ${health_checker_pids[@]}; do
done
echo 'end kill health checker'

-echo "所有进程已终止"
+echo "all process terminated."
exit 0
@@ -15,19 +15,16 @@

def check_basic_params(req_dict):
"""
-对单个输入请求进行基础的校验检查,适用于推拉模式。
-对输入的全部字段进行检查,统一将报错信息发送给用户,注意同一个字段的检查逻辑是独立的,避免重复的报错信息。
+checks input requests for basic parameters

Args:
-req_dict (dict): 请求的字典格式数据,包含文本、模型、序列长度、最大token数等字段。
+req_dict (dict): request parameters

Returns:
-list[str]: 如果校验有错误,返回错误信息列表,如果校验正确,返回空列表。
+list[str]: if error, return a list of error messages; return an empty list otherwise
"""

error_msg = []

-# text、input_ids和messages必须设置一个
bools = ("text" in req_dict, "input_ids" in req_dict, "messages" in req_dict)
if sum(bools) == 0:
error_msg.append("The input parameters should contain either `text`, `input_ids` or `messages`")
@@ -55,7 +52,6 @@ def check_basic_params(req_dict):
(not isinstance(req_dict["min_dec_len"], int) or req_dict["min_dec_len"] < 1):
error_msg.append("The `min_dec_len` must be an integer and greater than 0")

-# 如果设置了seq_len和max_tokens,最终都赋值给max_dec_len
keys = ("max_dec_len", "seq_len", "max_tokens")
for key in keys:
if key in req_dict and (not isinstance(req_dict[key], int) or req_dict[key] < 1):
@@ -65,7 +61,6 @@ def check_basic_params(req_dict):
if "max_tokens" in req_dict and "max_dec_len" not in req_dict:
req_dict["max_dec_len"] = req_dict["max_tokens"]

-# 简化处理,topp和top_p只允许有一个,且最终都赋值给topp
keys = ("topp", "top_p")
if sum([key in req_dict for key in keys]) > 1:
error_msg.append(f"Only one of {keys} should be set")
@@ -89,7 +84,6 @@ def check_basic_params(req_dict):
elif len(req_dict["eos_token_ids"]) > 1:
error_msg.append("The length of `eos_token_ids` must be 1 if you set it")

-# 简化处理,infer_seed和seed只允许有一个,且最终都赋值给infer_seed
keys = ("infer_seed", "seed")
if sum([key in req_dict for key in keys]) > 1:
error_msg.append(f"Only one of {keys} should be set")
@@ -103,15 +97,18 @@ def check_basic_params(req_dict):
if "response_type" in req_dict and (req_dict["response_type"].lower() not in ("fastdeploy", "openai")):
error_msg.append("The `response_type` must be either `fastdeploy` or `openai`.")

-# 返回信息
return error_msg


def add_default_params(req_dict):
"""
-给req_dict字典添加默认值。
-注意:虽然infer.py中设置请求参数有默认值,但为了统一,这里提前设置默认值。请保证此处默认值和infer.py中一致。
-返回添加默认值后的req_dict字典。
+add default params to req_dict

+Args:
+req_dict (dict): input dict
+
+Returns:
+dict: req_dict with default params
"""
assert isinstance(req_dict, dict), "The `req_dict` must be a dict."
if "min_dec_len" not in req_dict:
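The checks above define the accepted request shape. A request dict of that shape (illustrative only, not part of this diff; field names come from the checks in the hunks, the values are placeholders):

```python
# Sketch only: this dict would pass check_basic_params() and could then be
# filled in by add_default_params().
req = {
    "req_id": "demo-1",
    "text": "Hello, how are you?",   # exactly one of text / input_ids / messages
    "max_tokens": 64,                # copied into max_dec_len when max_dec_len is absent
    "top_p": 0.7,                    # only one of top_p / topp may be set
    "eos_token_ids": [2],            # at most one id is allowed
    "response_type": "fastdeploy",   # "fastdeploy" or "openai"
}
# check_basic_params(req) returns [] for a well-formed request; a non-empty list
# of error messages otherwise.
print(req)
```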
@@ -14,19 +14,15 @@

import os
from abc import ABC, abstractmethod
-from paddlenlp.utils.llm_utils import get_eos_token_id
-from paddlenlp.transformers import (
-LlamaTokenizer,
-Llama3Tokenizer,
-AutoTokenizer,
-)

-from server.utils import data_processor_logger
+from paddlenlp.transformers import Llama3Tokenizer, LlamaTokenizer
+from paddlenlp.utils.llm_utils import get_eos_token_id
from server.engine.config import Config
+from server.utils import data_processor_logger


class BaseDataProcessor(ABC):
-"""Data processor的基类"""
+"""base class for data processor"""

def __init__(self):
"""
@@ -75,59 +71,54 @@ class BaseDataProcessor(ABC):

def text2ids(self, text):
"""
-将文本转换为对应的 ID
+text to token ids

Args:
-text (str): 待转换的文本。
+text (str): text

Returns:
-List[int]: 转换后的 ID 列表。
+List[int]: token ids list
"""
raise NotImplementedError

def messages2ids(self, messages):
"""
-将多轮对话转换为对话ID序列。
+Convert multi-turn messages into ID sequences.

Args:
-messages (List[List[Dict[str, Any]]]): 对话列表,每个对话是一个字典。
+messages (List[List[Dict[str, Any]]]): multi-turn messages.

Returns:
-List[int]: 对话ID序列,每个ID是一个整数。
+List[int]: ID sequences

"""
raise NotImplementedError

def ids2tokens(self, token_ids, task_id=None):
"""
-将 token ids 解码为字符串
+token ids to strings

Args:
-token_ids (List[int]): 包含 token ids 的列表
+token_ids (List[int]): token ids
-task_id (str): 当前task_ids对应的任务ID
+task_id (str): task id

Returns:
-List[str]: 解码后的 tokenized 字符串列表
+List[str]: strings
"""
raise NotImplementedError

@abstractmethod
def _load_tokenizer(self):
"""
-加载分词器。
+load tokenizer

Returns:
-tokenizer: 分词器。
+tokenizer (AutoTokenizer)
"""
raise NotImplementedError


class DataProcessor(BaseDataProcessor):
-"""继承自Data processor的基类"""

def __init__(self):
-"""
-初始化函数。
-"""
self.config = Config()
max_length = self.config.get_model_config().get('max_length', 1024)
self.src_length = max_length - self.config.seq_len_limit
@@ -182,6 +173,7 @@ class DataProcessor(BaseDataProcessor):

token_ids = response_dict.get("token_ids", [])
response_dict["token"] = self.ids2tokens(token_ids, response_dict["req_id"])
+response_dict["usage"] = {"completion_tokens" : response_dict["send_idx"] + 1}

if is_end:
response_dict["tokens_all"] = self.clear_request_status(req_id)
@@ -189,8 +181,22 @@ class DataProcessor(BaseDataProcessor):

def text2ids(self, text):
"""
-text to ids
+text to token ids

+Args:
+text (str): text
+
+Returns:
+List[int]: token ids list
"""
+if self.config.use_hf_tokenizer:
+tokens = self.tokenizer(
+text,
+return_tensors="np",
+padding=True,
+truncation=True,
+)
+else:
if self.tokenizer.chat_template is not None:
text = [text] if isinstance(text, str) else text
text = [self.tokenizer.apply_chat_template(sentence, tokenize=False) for sentence in text]
@@ -207,23 +213,47 @@ class DataProcessor(BaseDataProcessor):

def messages2ids(self, messages):
"""
-将多轮对话转换为对话ID序列。
+Convert multi-turn messages into ID sequences.

Args:
-messages (List[List[Dict[str, Any]]]): 对话列表,每个对话是一个字典。
+messages (List[List[Dict[str, Any]]]): multi-turn messages.

Returns:
-List[int]: 对话ID序列,每个ID是一个整数。
+List[int]: ID sequences

"""
return

def ids2tokens(self, token_id, task_id):
"""
-ids to tokens
+token ids to strings

+Args:
+token_ids (List[int]): token ids
+task_id (str): task id
+
+Returns:
+List[str]: strings
"""
+if self.config.use_hf_tokenizer:
if task_id not in self.decode_status:
-# 记录deocde的prefix offset & read offset & history token ids & history token strings
+# history token ids & history token strings & befer decode str
+self.decode_status[task_id] = [[], [], ""]
+
+previous_token_ids = self.decode_status[task_id][0]
+decode_str = self.tokenizer.batch_decode([previous_token_ids + token_id],
+skip_special_tokens=True,
+clean_up_tokenization_spaces=False)
+if isinstance(decode_str, list) and len(decode_str):
+new_str = decode_str[0].replace(self.decode_status[task_id][2], "", 1)
+self.decode_status[task_id][1].append(new_str)
+self.decode_status[task_id][2] = decode_str[0]
+else:
+new_str = ""
+self.decode_status[task_id][0] += token_id
+return new_str
+else:
+if task_id not in self.decode_status:
+# prefix offset & read offset & history token ids & history token strings
self.decode_status[task_id] = [0, 0, [], []]

prefix_offset = self.decode_status[task_id][0]
@@ -235,7 +265,6 @@ class DataProcessor(BaseDataProcessor):
self.decode_status[task_id][1] = read_offset
self.decode_status[task_id][2] += token_id
self.decode_status[task_id][3].append(decode_str)
-# 此处为流式返回中的每个token字符串结果,可自行添加处理
return decode_str

def _load_tokenizer(self):
@@ -245,14 +274,28 @@ class DataProcessor(BaseDataProcessor):
Returns:
tokenizer (AutoTokenizer)
"""
+if self.config.use_hf_tokenizer:
+from transformers import AutoTokenizer
+return AutoTokenizer.from_pretrained(self.config.model_dir, use_fast=False)
+else:
+from paddlenlp.transformers import AutoTokenizer
return AutoTokenizer.from_pretrained(self.config.model_dir)

def clear_request_status(self, task_id):
"""
clear request status

+Args:
+task_id (str): task id
+
+Returns:
+results_all (str): all token strings
"""
results_all = ""
if task_id in self.decode_status:
+if self.config.use_hf_tokenizer:
+results_all = self.decode_status[task_id][2]
+else:
results_all = "".join(self.decode_status[task_id][3])
del self.decode_status[task_id]
return results_all
@@ -260,18 +303,27 @@ class DataProcessor(BaseDataProcessor):
def get_eos_tokens_lens(self):
"""
get eos_token_id lens

+Returns:
+int: eos_token_id lens
"""
return len(get_eos_token_id(self.tokenizer, self.config.generation_config))

def get_eos_tokens(self):
"""
get all eos_token_id

+Returns:
+List[int]: eos_token_id list
"""
return get_eos_token_id(self.tokenizer, self.config.generation_config)

def get_pad_id(self):
"""
get pad_token_id, if not pad_token_id, use eos_token

+Returns:
+int: pad_token_id
"""
if isinstance(self.tokenizer, (LlamaTokenizer, Llama3Tokenizer)) and not self.tokenizer.pad_token_id:
return self.tokenizer.eos_token
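The new `use_hf_tokenizer` branch of `ids2tokens` above does incremental detokenization: it re-decodes the full token history for a task and emits only the suffix that was not returned before. A dependency-free sketch of that idea (illustrative only, not part of this diff; `fake_batch_decode` stands in for `tokenizer.batch_decode`):

```python
# Sketch only: per-task state mirrors [history_token_ids, emitted_strings, previous_full_text].
decode_status = {}

def fake_batch_decode(token_ids):
    # Stand-in for tokenizer.batch_decode(); maps each id to a fixed text piece.
    vocab = {1: "Hel", 2: "lo", 3: ", ", 4: "world"}
    return ["".join(vocab[t] for t in token_ids)]

def ids2tokens(token_ids, task_id):
    state = decode_status.setdefault(task_id, [[], [], ""])
    full_text = fake_batch_decode(state[0] + token_ids)[0]
    new_str = full_text.replace(state[2], "", 1)  # strip the previously emitted prefix
    state[1].append(new_str)
    state[2] = full_text
    state[0] += token_ids
    return new_str

print(ids2tokens([1, 2], "t1"))  # "Hello"
print(ids2tokens([3, 4], "t1"))  # ", world"
```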
|
@@ -16,14 +16,14 @@ import json
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from paddlenlp.generation import GenerationConfig
|
|
||||||
|
|
||||||
|
from paddlenlp.generation import GenerationConfig
|
||||||
from server.utils import model_server_logger
|
from server.utils import model_server_logger
|
||||||
|
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
"""
|
"""
|
||||||
初始化配置,各参数优先以环境变量配置的值为准
|
initial configuration
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@@ -31,7 +31,7 @@ class Config:
|
|||||||
|
|
||||||
def read_from_env(self):
|
def read_from_env(self):
|
||||||
"""
|
"""
|
||||||
从环境变量中读取参数
|
get the configuration from environment
|
||||||
"""
|
"""
|
||||||
env = os.environ
|
env = os.environ
|
||||||
self.model_dir = env.get(
|
self.model_dir = env.get(
|
||||||
@@ -44,12 +44,12 @@ class Config:
|
|||||||
if env.get("FD_MODEL_CONFIG_PATH", None):
|
if env.get("FD_MODEL_CONFIG_PATH", None):
|
||||||
self.model_config_path = env.get("FD_MODEL_CONFIG_PATH")
|
self.model_config_path = env.get("FD_MODEL_CONFIG_PATH")
|
||||||
|
|
||||||
# 分布式配置文件
|
# distributed config
|
||||||
self.distributed_config_path = os.path.join(self.model_dir, "rank_mapping.csv")
|
self.distributed_config_path = os.path.join(self.model_dir, "rank_mapping.csv")
|
||||||
if os.getenv("DISTRIBUTED_CONFIG", None):
|
if os.getenv("DISTRIBUTED_CONFIG", None):
|
||||||
self.distributed_config_path = os.getenv("DISTRIBUTED_CONFIG")
|
self.distributed_config_path = os.getenv("DISTRIBUTED_CONFIG")
|
||||||
|
|
||||||
# 硬件配置信息
|
# device config
|
||||||
self.device = env.get("DEVICE", "GPU")
|
self.device = env.get("DEVICE", "GPU")
|
||||||
self.device_ids = ",".join([str(i) for i in range(self.mp_num)])
|
self.device_ids = ",".join([str(i) for i in range(self.mp_num)])
|
||||||
if self.device == "GPU":
|
if self.device == "GPU":
|
||||||
@@ -58,15 +58,15 @@ class Config:
|
|||||||
else:
|
else:
|
||||||
raise Exception(f"unsupported device type: {self.device}")
|
raise Exception(f"unsupported device type: {self.device}")
|
||||||
|
|
||||||
# Triton服务层参数
|
# Triton config
|
||||||
self.max_prefill_batch = int(os.getenv("MAX_PREFILL_BATCH", 1))
|
self.max_prefill_batch = int(os.getenv("MAX_PREFILL_BATCH", 1))
|
||||||
if self.max_prefill_batch <= 0:
|
if self.max_prefill_batch <= 0:
|
||||||
raise Exception(f"MAX_PREFILL_BATCH ({self.max_prefill_batch}) must be greater than 0")
|
raise Exception(f"MAX_PREFILL_BATCH ({self.max_prefill_batch}) must be greater than 0")
|
||||||
self.disable_streaming = int(os.getenv("DISABLE_STREAMING", 0))
|
self.disable_streaming = int(os.getenv("DISABLE_STREAMING", 0))
|
||||||
|
|
||||||
# 最大支持缓存的task数
|
# max cached task num
|
||||||
self.max_cached_task_num = int(os.getenv("MAX_CACHED_TASK_NUM", "128"))
|
self.max_cached_task_num = int(os.getenv("MAX_CACHED_TASK_NUM", "128"))
|
||||||
# 如果没有配置PUSH_MODE_HTTP_PORT, 则只支持 GRPC 服务模式
|
# if PUSH_MODE_HTTP_PORT is not configured, only GRPC service is enabled
|
||||||
self.push_mode_http_port = int(os.getenv("PUSH_MODE_HTTP_PORT", "-1"))
|
self.push_mode_http_port = int(os.getenv("PUSH_MODE_HTTP_PORT", "-1"))
|
||||||
if self.push_mode_http_port > 0:
|
if self.push_mode_http_port > 0:
|
||||||
grpc_port = os.getenv("GRPC_PORT", None)
|
grpc_port = os.getenv("GRPC_PORT", None)
|
||||||
@@ -74,25 +74,25 @@ class Config:
|
|||||||
raise Exception("GRPC_PORT cannot be None, while PUSH_MODE_HTTP_PORT>0")
|
raise Exception("GRPC_PORT cannot be None, while PUSH_MODE_HTTP_PORT>0")
|
||||||
self.grpc_port = int(grpc_port)
|
self.grpc_port = int(grpc_port)
|
||||||
|
|
||||||
# http服务线的worker数
|
# http worker num
|
||||||
self.push_mode_http_workers = int(os.getenv("PUSH_MODE_HTTP_WORKERS", "1"))
|
self.push_mode_http_workers = int(os.getenv("PUSH_MODE_HTTP_WORKERS", "1"))
|
||||||
if self.push_mode_http_workers < 1:
|
if self.push_mode_http_workers < 1:
|
||||||
raise Exception(f"PUSH_MODE_HTTP_WORKERS ({self.push_mode_http_workers}) must be positive")
|
raise Exception(f"PUSH_MODE_HTTP_WORKERS ({self.push_mode_http_workers}) must be positive")
|
||||||
|
|
||||||
# 导出Paddle代码版本,便于对比版本号
|
# Padlle commit id
|
||||||
import paddle
|
import paddle
|
||||||
self.paddle_commit_id = paddle.version.commit
|
self.paddle_commit_id = paddle.version.commit
|
||||||
|
|
||||||
# 探活时检测engine主循环是否正常的时间间隔
|
# time interval for detecting whether the engine loop is normal during probing
|
||||||
self.check_health_interval = int(os.getenv("CHECK_HEALTH_INTERVAL", 10))
|
self.check_health_interval = int(os.getenv("CHECK_HEALTH_INTERVAL", 10))
|
||||||
|
|
||||||
# 与模型相关信息(注意要与导出的模型保持一致,否则存在效果问题)
|
# model config
|
||||||
self.dtype = env.get("DTYPE", "bfloat16")
|
self.dtype = env.get("DTYPE", "bfloat16")
|
||||||
self.block_size = int(env.get("BLOCK_SIZE", 64))
|
self.block_size = int(env.get("BLOCK_SIZE", 64))
|
||||||
self.use_cache_kv_int8 = int(os.getenv("USE_CACHE_KV_INT8", 0))
|
self.use_cache_kv_int8 = int(os.getenv("USE_CACHE_KV_INT8", 0))
|
||||||
self.use_cache_kv_int4 = int(os.getenv("USE_CACHE_KV_INT4", 0))
|
self.use_cache_kv_int4 = int(os.getenv("USE_CACHE_KV_INT4", 0))
|
||||||
|
|
||||||
# 推理引擎配置
|
# infer config
|
||||||
self.max_batch_size = int(env.get("BATCH_SIZE", 50))
|
self.max_batch_size = int(env.get("BATCH_SIZE", 50))
|
||||||
self.max_seq_len = int(env.get("MAX_SEQ_LEN", 8192))
|
self.max_seq_len = int(env.get("MAX_SEQ_LEN", 8192))
|
||||||
self.max_dec_len = int(env.get("MAX_DEC_LEN", 1024))
|
self.max_dec_len = int(env.get("MAX_DEC_LEN", 1024))
|
||||||
@@ -102,14 +102,14 @@ class Config:
|
|||||||
self.bad_tokens = str(env.get("BAD_TOKENS", "-1"))
|
self.bad_tokens = str(env.get("BAD_TOKENS", "-1"))
|
||||||
self.first_token_id = int(os.getenv("FIRST_TOKEN_ID", 1))
|
self.first_token_id = int(os.getenv("FIRST_TOKEN_ID", 1))
|
||||||
|
|
||||||
# 引擎输入队列端口号
|
# infer queue port
|
||||||
self.infer_port = int(os.getenv("INFER_QUEUE_PORT", 56666))
|
self.infer_port = int(os.getenv("INFER_QUEUE_PORT", 56666))
|
||||||
|
|
||||||
# 是否开启探活服务
|
# whether to use custom health checker
|
||||||
self.use_custom_health_checker = int(os.getenv("USE_CUSTOM_HEALTH_CHECKER", 1))
|
self.use_custom_health_checker = int(os.getenv("USE_CUSTOM_HEALTH_CHECKER", 1))
|
||||||
|
|
||||||
# 环境变量配置MAX_SEQ_LEN,MAX_DEC_LEN将用于控制服务请求合法性检查
|
# Check the legality of requests
|
||||||
self.seq_len_limit = int(env.get("MAX_SEQ_LEN", 7168))
|
self.seq_len_limit = int(env.get("MAX_SEQ_LEN", 8192))
|
||||||
self.dec_len_limit = int(env.get("MAX_DEC_LEN", 1024))
|
self.dec_len_limit = int(env.get("MAX_DEC_LEN", 1024))
|
||||||
|
|
||||||
# warmup
|
# warmup
|
||||||
@@ -118,7 +118,10 @@ class Config:
|
|||||||
# uuid
|
# uuid
|
||||||
self.shm_uuid = os.getenv("SHM_UUID", '')
|
self.shm_uuid = os.getenv("SHM_UUID", '')
|
||||||
|
|
||||||
# 加载 Generation 文件
|
# use huggingface tokenizer
|
||||||
|
self.use_hf_tokenizer = int(os.getenv("USE_HF_TOKENIZER", 0)) == 1
|
||||||
|
|
||||||
|
# Generation config
|
||||||
try:
|
try:
|
||||||
self.generation_config = GenerationConfig.from_pretrained(self.model_dir)
|
self.generation_config = GenerationConfig.from_pretrained(self.model_dir)
|
||||||
except:
|
except:
|
||||||
@@ -133,7 +136,7 @@ class Config:
|
|||||||
|
|
||||||
def postprocess(self):
|
def postprocess(self):
|
||||||
"""
|
"""
|
||||||
根据配置参数,计算部分额外的参数
|
calculate some parameters
|
||||||
"""
|
"""
|
||||||
if self.block_ratio >= 1.0:
|
if self.block_ratio >= 1.0:
|
||||||
self.enc_dec_block_num = (self.max_dec_len + self.block_size - 1) // self.block_size
|
self.enc_dec_block_num = (self.max_dec_len + self.block_size - 1) // self.block_size
|
||||||
@@ -148,7 +151,7 @@ class Config:
|
|||||||
|
|
||||||
def check(self):
|
def check(self):
|
||||||
"""
|
"""
|
||||||
检查参数配置合法性
|
check the legality of config
|
||||||
"""
|
"""
|
||||||
assert self.max_batch_size <= 256, (
|
assert self.max_batch_size <= 256, (
|
||||||
"The parameter `max_batch_size` is not allowed to exceed 256, "
|
"The parameter `max_batch_size` is not allowed to exceed 256, "
|
||||||
@@ -167,10 +170,10 @@ class Config:
|
|||||||
|
|
||||||
def print(self, file=None):
|
def print(self, file=None):
|
||||||
"""
|
"""
|
||||||
输出所有参数配置
|
print all config
|
||||||
|
|
||||||
file: 如若指定file路径,同时将日志以追加方式写入到另外的文件中
|
Args:
|
||||||
解决当前日志系统仅保留7天,无法追查启动信息问题
|
file (str): the path of file to save config
|
||||||
"""
|
"""
|
||||||
model_server_logger.info(
|
model_server_logger.info(
|
||||||
"=================== Configuration Information ===============")
|
"=================== Configuration Information ===============")
|
||||||
@@ -192,14 +195,17 @@ class Config:
|
|||||||
|
|
||||||
def get_model_config(self):
|
def get_model_config(self):
|
||||||
"""
|
"""
|
||||||
读取模型配置文件
|
load config file
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: the config file
|
||||||
"""
|
"""
|
||||||
model_config_json = json.load(open(self.model_config_path, 'r', encoding='utf-8'))
|
model_config_json = json.load(open(self.model_config_path, 'r', encoding='utf-8'))
|
||||||
return model_config_json
|
return model_config_json
|
||||||
|
|
||||||
def read_from_config(self):
|
def read_from_config(self):
|
||||||
"""
|
"""
|
||||||
从配置文件中读取参数
|
reset model config from json file
|
||||||
"""
|
"""
|
||||||
from server.utils import get_logger
|
from server.utils import get_logger
|
||||||
logger = get_logger("model_server", "infer_config.log")
|
logger = get_logger("model_server", "infer_config.log")
|
||||||
@@ -218,6 +224,12 @@ class Config:
|
|||||||
assert self.dec_len_limit <= self.max_seq_len, f"The loading model requires MAX_DEC_LEN <= {self.max_seq_len}, but now the setting MAX_DEC_LEN={self.dec_len_limit}."
|
assert self.dec_len_limit <= self.max_seq_len, f"The loading model requires MAX_DEC_LEN <= {self.max_seq_len}, but now the setting MAX_DEC_LEN={self.dec_len_limit}."
|
||||||
|
|
||||||
def get_unique_name(self, name):
|
def get_unique_name(self, name):
|
||||||
|
"""
|
||||||
|
get unique name
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (str): the name add uuid
|
||||||
|
"""
|
||||||
return name + f"_{self.shm_uuid}"
|
return name + f"_{self.shm_uuid}"
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
|
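Since `Config.read_from_env()` above takes every value from environment variables, a deployment is configured entirely through the process environment. An illustrative sketch (not part of this diff; variable names come from the hunks and the tutorial table, the values are placeholders, and instantiating `Config` itself requires the server package and a model directory):

```python
# Sketch only: sets the environment the Config class reads at construction time.
import os

os.environ.update({
    "MODEL_DIR": "/models/",         # model files
    "MP_NUM": "1",                   # tensor-parallel degree
    "GRPC_PORT": "8811",
    "PUSH_MODE_HTTP_PORT": "9965",   # <= 0 disables the HTTP entry point
    "MAX_SEQ_LEN": "8192",           # also used as the request length limit
    "MAX_DEC_LEN": "1024",
    "USE_HF_TOKENIZER": "0",         # 1 switches to the Hugging Face tokenizer path
})

# from server.engine.config import Config   # import path taken from the diff
# cfg = Config()                            # would pick up the values above
print({k: os.environ[k] for k in ("GRPC_PORT", "PUSH_MODE_HTTP_PORT")})
```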
@@ -12,29 +12,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+import multiprocessing
import os
import signal
import subprocess
import time
import uuid
import weakref
-import multiprocessing
-import numpy as np
from datetime import datetime
from multiprocessing import shared_memory

-from server.engine.task_queue_manager import (
-TaskQueueManager,
-launch_queue_service,
-)
+import numpy as np
from server.engine.resource_manager import ResourceManager
+from server.engine.task_queue_manager import (TaskQueueManager,
+launch_queue_service)
from server.engine.token_processor import TokenProcessor, WarmUpTokenProcessor
from server.utils import model_server_logger


class Engine(object):
"""
-底层推理引擎,维护队列用于引擎使用
+Engine Class
"""
def __init__(self, cfg, token_processor):
self.cfg = cfg
@@ -44,31 +42,25 @@ class Engine(object):
self.is_started = False

self._init_engine_flags()
-# 此处函数可考虑是否注释,添加后,如果引擎结束
-# 会自动结束队列进程和推理infer进程
self._finalizer = weakref.finalize(self, self._exit_sub_services)

def start(self):
"""
-初始化引擎所需的各进程
+initialize engine and start sub services
"""
assert not self.is_started, "The engine is already started.!"
start_time = time.time()
-# 启动队列进程(服务层与引擎层通信)服务
self.queue_service = self._start_tasks_queue_service()
self.tasks_queue = TaskQueueManager(mp_num=self.cfg.mp_num, port=self.cfg.infer_port)

-# 由于BeamSearch在后处理时依赖queue与infer.py进行通信
-# 此处将tasks_queue共享给TokenProcessor
self.token_processor.tasks_queue = self.tasks_queue

self.infer_proc = self._start_infer_service()
model_server_logger.info("Waitting infer processes ready...")
while not self._infer_processes_ready():
time.sleep(1)
self.is_started = True

-# 启动warmup
+# start warmup
if self.cfg.use_warmup:
model_server_logger.info("Start warmup")
self._set_warmup_token_processor()
@@ -76,19 +68,19 @@ class Engine(object):
self._del_warmup_token_processor()
model_server_logger.info("Warmup finish")

-# 启动TokenProcessor子线程
+# start TokenProcessor thread
self.token_processor.run()
model_server_logger.info("Infer processes are launched with {} seconds.".format(time.time() - start_time))

def warmup(self):
"""
-通过构造测试数据进行推理,确保推理过程中不会出现OOM,能够正常进行推理
+construct test tasks and avoid out of memory problem in the infer process
"""
-# 获取eos_token_id
+# get eos_token_id
from server.data.processor import DataProcessor
eos_token_ids = DataProcessor().get_eos_tokens()

-# 构造测试任务数据
+# construct test tasks
res_task = []
for j in range(2 * self.cfg.max_batch_size):
data = {
@@ -109,19 +101,24 @@ class Engine(object):
}
res_task.append(data)

-# 插入任务
for x in res_task:
while self.available_batch() == 0 or not self.insert_tasks([x]):
time.sleep(0.0002)

self.token_processor._is_blocking = False
-# 等待所有数据推理结束
+# wait for all tasks finished
while not self.all_tasks_finished():
time.sleep(1)

def insert_tasks(self, tasks):
"""
-插入任务到引擎队列
+insert tasks to the engine

+Args:
+tasks: list of tasks
+
+Returns:
+return: True if success, False otherwise
"""
if not isinstance(tasks, list):
tasks = [tasks]
@@ -144,11 +141,13 @@ class Engine(object):
tasks[i]["input_ids"] = tasks[i]["input_ids"][:self.cfg.max_seq_len - 1]
if "seq_len" in tasks[i] and "max_dec_len" not in tasks[i]:
tasks[i]["max_dec_len"] = tasks[i]["seq_len"]

# max_dec_len + input_token_num > MAX_SEQ_LEN
if input_token_num + tasks[i]["max_dec_len"] > self.cfg.max_seq_len:
|
if input_token_num + tasks[i]["max_dec_len"] > self.cfg.max_seq_len:
|
||||||
tasks[i]["max_dec_len"] = self.cfg.max_seq_len - input_token_num
|
tasks[i]["max_dec_len"] = self.cfg.max_seq_len - input_token_num
|
||||||
model_server_logger.warning("Force max_dec_len to be {} for req_id={}.".format(
|
model_server_logger.warning("Force max_dec_len to be {} for req_id={}.".format(
|
||||||
tasks[i]["max_dec_len"], tasks[i]["req_id"]))
|
tasks[i]["max_dec_len"], tasks[i]["req_id"]))
|
||||||
|
|
||||||
# min_dec_len + input_token_num > MAX_SEQ_LEN
|
# min_dec_len + input_token_num > MAX_SEQ_LEN
|
||||||
if input_token_num + tasks[i]["min_dec_len"] > self.cfg.max_seq_len:
|
if input_token_num + tasks[i]["min_dec_len"] > self.cfg.max_seq_len:
|
||||||
tasks[i]["min_dec_len"] = self.cfg.max_seq_len - input_token_num
|
tasks[i]["min_dec_len"] = self.cfg.max_seq_len - input_token_num
|
||||||
@@ -170,67 +169,94 @@ class Engine(object):
|
|||||||
|
|
||||||
def task_is_finished(self, index):
|
def task_is_finished(self, index):
|
||||||
"""
|
"""
|
||||||
判断相应位置的任务是否完成
|
judge if the task is finished
|
||||||
|
|
||||||
|
Args:
|
||||||
|
index: task index
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
return: True if finished, False otherwise
|
||||||
"""
|
"""
|
||||||
assert index < len(self.resource_manager.stop_flags)
|
assert index < len(self.resource_manager.stop_flags)
|
||||||
return self.resource_manager.stop_flags[index]
|
return self.resource_manager.stop_flags[index]
|
||||||
|
|
||||||
def is_queue_empty(self):
|
def is_queue_empty(self):
|
||||||
"""
|
"""
|
||||||
判断引擎队列是否为空
|
judge if the queue is empty
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
return: True if empty, False otherwise
|
||||||
"""
|
"""
|
||||||
return self.tasks_queue.empty()
|
return self.tasks_queue.empty()
|
||||||
|
|
||||||
def is_resource_sufficient(self, input_token_num):
|
def is_resource_sufficient(self, input_token_num):
|
||||||
"""
|
"""
|
||||||
根据输入的token id长度,判断引擎资源是否充足
|
judge if the resource is sufficient
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_token_num: input token number
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
return: True if sufficient, False otherwise
|
||||||
"""
|
"""
|
||||||
return self.resource_manager.is_resource_sufficient(input_token_num)
|
return self.resource_manager.is_resource_sufficient(input_token_num)
|
||||||
|
|
||||||
def all_tasks_finished(self):
|
def all_tasks_finished(self):
|
||||||
"""
|
"""
|
||||||
判断是否所有的引擎正在计算的任务已完成
|
judge if all tasks are finished
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
return: True if all finished, False otherwise
|
||||||
"""
|
"""
|
||||||
return np.sum(self.resource_manager.stop_flags) == len(self.resource_manager.stop_flags)
|
return np.sum(self.resource_manager.stop_flags) == len(self.resource_manager.stop_flags)
|
||||||
|
|
||||||
def available_batch(self):
|
def available_batch(self):
|
||||||
"""
|
"""
|
||||||
引擎当前可用的最大Batch
|
available batch size of the engine
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
return: available batch size
|
||||||
"""
|
"""
|
||||||
return self.resource_manager.available_batch()
|
return self.resource_manager.available_batch()
|
||||||
|
|
||||||
def available_block_num(self):
|
def available_block_num(self):
|
||||||
"""
|
"""
|
||||||
引擎当前可用的block数量
|
available block number of the engine
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
return: available block number
|
||||||
"""
|
"""
|
||||||
return self.resource_manager.availabel_block_num()
|
return self.resource_manager.availabel_block_num()
|
||||||
|
|
||||||
def _set_warmup_token_processor(self):
|
def _set_warmup_token_processor(self):
|
||||||
"""
|
"""
|
||||||
设置token_processor,用于warmup阶段
|
set token_processor for warmup
|
||||||
"""
|
"""
|
||||||
self.token_processor_backup = self.token_processor
|
self.token_processor_backup = self.token_processor
|
||||||
self.token_processor = WarmUpTokenProcessor(self.cfg)
|
self.token_processor = WarmUpTokenProcessor(self.cfg)
|
||||||
# 设置resource_manager
|
|
||||||
self.token_processor.set_resource_manager(self.resource_manager)
|
self.token_processor.set_resource_manager(self.resource_manager)
|
||||||
self.token_processor.tasks_queue = self.tasks_queue
|
self.token_processor.tasks_queue = self.tasks_queue
|
||||||
# 启动TokenProcessor子线程
|
|
||||||
|
# start TokenProcessor thread
|
||||||
self.token_processor.run()
|
self.token_processor.run()
|
||||||
|
|
||||||
def _del_warmup_token_processor(self):
|
def _del_warmup_token_processor(self):
|
||||||
"""
|
"""
|
||||||
删除token_processor,用于正常推理阶段
|
delete token_processor for warmup
|
||||||
"""
|
"""
|
||||||
# 停止worker 线程
|
|
||||||
self.token_processor.stop()
|
self.token_processor.stop()
|
||||||
del self.token_processor
|
del self.token_processor
|
||||||
# 恢复token_processor
|
|
||||||
|
# reset token_processor
|
||||||
self.token_processor = self.token_processor_backup
|
self.token_processor = self.token_processor_backup
|
||||||
del self.token_processor_backup
|
del self.token_processor_backup
|
||||||
|
|
||||||
def _infer_processes_ready(self):
|
def _infer_processes_ready(self):
|
||||||
"""
|
"""
|
||||||
判断引擎是否初始化完成
|
judge if all infer processes are ready
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
return: True if all ready, False otherwise
|
||||||
"""
|
"""
|
||||||
if np.sum(self.flag_ready_array) == self.cfg.mp_num:
|
if np.sum(self.flag_ready_array) == self.cfg.mp_num:
|
||||||
return True
|
return True
|
||||||
@@ -238,7 +264,7 @@ class Engine(object):
|
|||||||
|
|
||||||
def _clear_engine_flags(self):
|
def _clear_engine_flags(self):
|
||||||
"""
|
"""
|
||||||
清除共享内存
|
clear engine flags
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
self.shm_flag_ready.close()
|
self.shm_flag_ready.close()
|
||||||
@@ -250,9 +276,8 @@ class Engine(object):
|
|||||||
|
|
||||||
def _init_engine_flags(self):
|
def _init_engine_flags(self):
|
||||||
"""
|
"""
|
||||||
初始化各共享内存,用于指示引擎状态
|
Initialize shared memory to indicate engine status
|
||||||
"""
|
"""
|
||||||
# 标记是否启动
|
|
||||||
flag_array = np.zeros([self.cfg.mp_num], dtype=np.int32)
|
flag_array = np.zeros([self.cfg.mp_num], dtype=np.int32)
|
||||||
try:
|
try:
|
||||||
tmp = shared_memory.SharedMemory(
|
tmp = shared_memory.SharedMemory(
|
||||||
@@ -270,7 +295,7 @@ class Engine(object):
|
|||||||
)
|
)
|
||||||
self.flag_ready_array[:] = 0
|
self.flag_ready_array[:] = 0
|
||||||
|
|
||||||
# 广播读取数据
|
# broadcast flag for engine
|
||||||
broadcast_flag_array = np.zeros([1], dtype=np.int32)
|
broadcast_flag_array = np.zeros([1], dtype=np.int32)
|
||||||
try:
|
try:
|
||||||
tmp = shared_memory.SharedMemory(
|
tmp = shared_memory.SharedMemory(
|
||||||
@@ -292,7 +317,6 @@ class Engine(object):
|
|||||||
)
|
)
|
||||||
self.flag_broadcast_array[0] = 0
|
self.flag_broadcast_array[0] = 0
|
||||||
|
|
||||||
# 标记引擎是否有调度出去的query
|
|
||||||
has_block_step_flag_array = np.zeros([1], dtype=np.int32)
|
has_block_step_flag_array = np.zeros([1], dtype=np.int32)
|
||||||
try:
|
try:
|
||||||
tmp = shared_memory.SharedMemory(
|
tmp = shared_memory.SharedMemory(
|
||||||
@@ -314,6 +338,9 @@ class Engine(object):
|
|||||||
self.flag_has_block_step_array[:] = 0
|
self.flag_has_block_step_array[:] = 0
|
||||||
|
|
||||||
def _exit_sub_services(self):
|
def _exit_sub_services(self):
|
||||||
|
"""
|
||||||
|
exit sub services
|
||||||
|
"""
|
||||||
if hasattr(self, "queue_service") and self.queue_service is not None:
|
if hasattr(self, "queue_service") and self.queue_service is not None:
|
||||||
self.queue_service.terminate()
|
self.queue_service.terminate()
|
||||||
self.queue_service.join()
|
self.queue_service.join()
|
||||||
@@ -321,6 +348,12 @@ class Engine(object):
|
|||||||
os.killpg(self.infer_proc.pid, signal.SIGTERM)
|
os.killpg(self.infer_proc.pid, signal.SIGTERM)
|
||||||
|
|
||||||
def _start_tasks_queue_service(self):
|
def _start_tasks_queue_service(self):
|
||||||
|
"""
|
||||||
|
start tasks queue service
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
p: process handle
|
||||||
|
"""
|
||||||
p = multiprocessing.Process(target=launch_queue_service, args=(self.cfg.infer_port, self.cfg.mp_num))
|
p = multiprocessing.Process(target=launch_queue_service, args=(self.cfg.infer_port, self.cfg.mp_num))
|
||||||
p.start()
|
p.start()
|
||||||
time.sleep(0.3)
|
time.sleep(0.3)
|
||||||
@@ -335,7 +368,10 @@ class Engine(object):
|
|||||||
|
|
||||||
def _start_gpu_infer_service(self):
|
def _start_gpu_infer_service(self):
|
||||||
"""
|
"""
|
||||||
GPU模型推理进程启动
|
start gpu infer service
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
p: process handle
|
||||||
"""
|
"""
|
||||||
current_file_path = os.path.abspath(__file__)
|
current_file_path = os.path.abspath(__file__)
|
||||||
current_dir_path = os.path.split(current_file_path)[0]
|
current_dir_path = os.path.split(current_file_path)[0]
|
||||||
@@ -360,6 +396,6 @@ class Engine(object):
|
|||||||
|
|
||||||
def _start_infer_service(self):
|
def _start_infer_service(self):
|
||||||
"""
|
"""
|
||||||
启动模型推理进程
|
start infer service
|
||||||
"""
|
"""
|
||||||
return self._start_gpu_infer_service()
|
return self._start_gpu_infer_service()
|
||||||
|
@@ -18,20 +18,19 @@ import json
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
import numpy as np
|
|
||||||
from multiprocessing import shared_memory
|
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from multiprocessing import shared_memory
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import paddle
|
import paddle
|
||||||
import paddle.distributed as dist
|
import paddle.distributed as dist
|
||||||
import paddle.distributed.fleet as fleet
|
import paddle.distributed.fleet as fleet
|
||||||
from paddlenlp_ops import step_paddle
|
|
||||||
from paddlenlp.utils.llm_utils import get_rotary_position_embedding
|
from paddlenlp.utils.llm_utils import get_rotary_position_embedding
|
||||||
|
from paddlenlp_ops import step_paddle
|
||||||
from server.utils import get_logger
|
|
||||||
from server.engine.config import Config
|
|
||||||
from task_queue_manager import TaskQueueManager
|
|
||||||
from server.data.processor import DataProcessor
|
from server.data.processor import DataProcessor
|
||||||
|
from server.engine.config import Config
|
||||||
|
from server.utils import get_logger
|
||||||
|
from task_queue_manager import TaskQueueManager
|
||||||
|
|
||||||
File_Path = os.path.realpath(sys.argv[0])
|
File_Path = os.path.realpath(sys.argv[0])
|
||||||
Dir_Path = os.path.dirname(File_Path)
|
Dir_Path = os.path.dirname(File_Path)
|
||||||
@@ -42,7 +41,8 @@ class ModelRunner:
|
|||||||
def __init__(self, args):
|
def __init__(self, args):
|
||||||
self.args = args
|
self.args = args
|
||||||
|
|
||||||
self.MAX_INFER_SEED = 9223372036854775806 # 2**63 - 1
|
# 2**63 - 1
|
||||||
|
self.MAX_INFER_SEED = 9223372036854775806
|
||||||
|
|
||||||
self.config = Config()
|
self.config = Config()
|
||||||
self.model_cfg = self.config.get_model_config()
|
self.model_cfg = self.config.get_model_config()
|
||||||
@@ -77,12 +77,18 @@ class ModelRunner:
|
|||||||
|
|
||||||
def read_model_config(self):
|
def read_model_config(self):
|
||||||
"""
|
"""
|
||||||
读取通用模型配置文件
|
load model config file from json file
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
model_config_json: dict, model config file
|
||||||
"""
|
"""
|
||||||
model_config_json = json.load(open(self.config_file, 'r', encoding='utf-8'))
|
model_config_json = json.load(open(self.config_file, 'r', encoding='utf-8'))
|
||||||
return model_config_json
|
return model_config_json
|
||||||
|
|
||||||
def get_value(self, cfg, names):
|
def get_value(self, cfg, names):
|
||||||
|
"""
|
||||||
|
get value from config file by key names
|
||||||
|
"""
|
||||||
if not isinstance(names, list):
|
if not isinstance(names, list):
|
||||||
names = [names]
|
names = [names]
|
||||||
for name in names:
|
for name in names:
|
||||||
@@ -95,7 +101,7 @@ class ModelRunner:
|
|||||||
|
|
||||||
def format_print_configuration(self):
|
def format_print_configuration(self):
|
||||||
"""
|
"""
|
||||||
输出配置信息
|
print model config
|
||||||
"""
|
"""
|
||||||
logger.info("=============== Model Information ==============")
|
logger.info("=============== Model Information ==============")
|
||||||
for k, v in self.model_cfg.items():
|
for k, v in self.model_cfg.items():
|
||||||
@@ -106,6 +112,9 @@ class ModelRunner:
|
|||||||
logger.info("=====================================================\n")
|
logger.info("=====================================================\n")
|
||||||
|
|
||||||
def load_model_init_val(self):
|
def load_model_init_val(self):
|
||||||
|
"""
|
||||||
|
initialize model config from config file
|
||||||
|
"""
|
||||||
self.top_p = self.model_cfg.get("top_p", 0.0)
|
self.top_p = self.model_cfg.get("top_p", 0.0)
|
||||||
self.temperature = self.model_cfg.get("temperature", 1.0)
|
self.temperature = self.model_cfg.get("temperature", 1.0)
|
||||||
self.rope_theta = self.model_cfg.get('rope_theta', 10000.0)
|
self.rope_theta = self.model_cfg.get('rope_theta', 10000.0)
|
||||||
@@ -117,15 +126,14 @@ class ModelRunner:
|
|||||||
self.max_length = self.model_cfg.get('max_length', 1024)
|
self.max_length = self.model_cfg.get('max_length', 1024)
|
||||||
|
|
||||||
data_processor = DataProcessor()
|
data_processor = DataProcessor()
|
||||||
# 允许用户配置一个额外的 eos_token 长度
|
# reserve an eos token for request
|
||||||
self.eos_tokens_lens = data_processor.get_eos_tokens_lens() + 1
|
self.eos_tokens_lens = data_processor.get_eos_tokens_lens() + 1
|
||||||
self.pad_token_id = data_processor.get_pad_id()
|
self.pad_token_id = data_processor.get_pad_id()
|
||||||
|
|
||||||
def init_dist_env(self, seed=20):
|
def init_dist_env(self, seed=20):
|
||||||
"""
|
"""
|
||||||
初始化分布式环境
|
init distributed env
|
||||||
"""
|
"""
|
||||||
# start to init distributed env
|
|
||||||
strategy = fleet.DistributedStrategy()
|
strategy = fleet.DistributedStrategy()
|
||||||
|
|
||||||
strategy.hybrid_configs = {
|
strategy.hybrid_configs = {
|
||||||
@@ -140,7 +148,7 @@ class ModelRunner:
|
|||||||
fleet.init(is_collective=True, strategy=strategy)
|
fleet.init(is_collective=True, strategy=strategy)
|
||||||
|
|
||||||
def init_inputs(self):
|
def init_inputs(self):
|
||||||
# 初始化输入,所有输入都share进引擎
|
# init all inputs
|
||||||
if "num_key_value_heads" in self.model_cfg and \
|
if "num_key_value_heads" in self.model_cfg and \
|
||||||
self.model_cfg["num_key_value_heads"] is not None and \
|
self.model_cfg["num_key_value_heads"] is not None and \
|
||||||
int(self.model_cfg["num_key_value_heads"]) > 0:
|
int(self.model_cfg["num_key_value_heads"]) > 0:
|
||||||
@@ -165,109 +173,82 @@ class ModelRunner:
|
|||||||
|
|
||||||
pre_max_block_num = (self.args.max_seq_len + self.args.block_size - 1) // self.args.block_size + self.args.enc_dec_block_num
|
pre_max_block_num = (self.args.max_seq_len + self.args.block_size - 1) // self.args.block_size + self.args.enc_dec_block_num
|
||||||
self.share_inputs["block_tables"] = paddle.full(
|
self.share_inputs["block_tables"] = paddle.full(
|
||||||
shape=[self.args.max_batch_size, pre_max_block_num],
|
shape=[self.args.max_batch_size, pre_max_block_num], fill_value=-1, dtype="int32")
|
||||||
fill_value=-1,
|
|
||||||
dtype="int32")
|
|
||||||
|
|
||||||
self.share_inputs['pre_ids'] = paddle.to_tensor(
|
self.share_inputs['pre_ids'] = paddle.to_tensor(
|
||||||
np.full((self.args.max_batch_size, self.args.max_dec_len), -1, dtype='int64'))
|
np.full((self.args.max_batch_size, self.args.max_dec_len), -1, dtype='int64'))
|
||||||
|
|
||||||
tmp_position_ids = paddle.arange(self.args.max_seq_len).reshape((1, -1))
|
tmp_position_ids = paddle.arange(self.args.max_seq_len).reshape((1, -1))
|
||||||
self.share_inputs['rope_emb'] = get_rotary_position_embedding(tmp_position_ids,
|
self.share_inputs['rope_emb'] = get_rotary_position_embedding(tmp_position_ids,
|
||||||
self.args.hidden_size // self.args.num_attention_heads, self.rope_theta, self.rope_scaling)
|
self.args.hidden_size // self.args.num_attention_heads,
|
||||||
|
self.rope_theta, self.rope_scaling)
|
||||||
self.share_inputs['input_ids'] = paddle.full(
|
self.share_inputs['input_ids'] = paddle.full(
|
||||||
shape=[self.args.max_batch_size, self.args.max_seq_len],
|
shape=[self.args.max_batch_size, self.args.max_seq_len],
|
||||||
fill_value=self.pad_token_id,
|
fill_value=self.pad_token_id, dtype='int64')
|
||||||
dtype='int64')
|
self.share_inputs['top_p'] = paddle.full(
|
||||||
self.share_inputs['top_p'] = paddle.full(shape=[self.args.max_batch_size, 1],
|
shape=[self.args.max_batch_size, 1], fill_value=self.top_p, dtype="float32")
|
||||||
fill_value=self.top_p,
|
self.share_inputs['temperature'] = paddle.full(
|
||||||
dtype="float32")
|
shape=[self.args.max_batch_size, 1], fill_value=self.temperature, dtype="float32")
|
||||||
self.share_inputs['temperature'] = paddle.full(shape=[self.args.max_batch_size, 1],
|
|
||||||
fill_value=self.temperature,
|
|
||||||
dtype="float32")
|
|
||||||
self.share_inputs['eos_token_id'] = paddle.to_tensor(
|
self.share_inputs['eos_token_id'] = paddle.to_tensor(
|
||||||
np.zeros((self.eos_tokens_lens, 1)).reshape(-1, 1).astype("int64"))
|
np.zeros((self.eos_tokens_lens, 1)).reshape(-1, 1).astype("int64"))
|
||||||
self.share_inputs['penalty_score'] = paddle.full(shape=[self.args.max_batch_size, 1],
|
self.share_inputs['penalty_score'] = paddle.full(
|
||||||
fill_value=self.penalty_score,
|
shape=[self.args.max_batch_size, 1], fill_value=self.penalty_score, dtype="float32")
|
||||||
dtype="float32")
|
self.share_inputs['frequency_score'] = paddle.full(
|
||||||
self.share_inputs['frequency_score'] = paddle.full(shape=[self.args.max_batch_size, 1],
|
shape=[self.args.max_batch_size, 1], fill_value=self.frequency_score, dtype="float32")
|
||||||
fill_value=self.frequency_score,
|
self.share_inputs['presence_score'] = paddle.full(
|
||||||
dtype="float32")
|
shape=[self.args.max_batch_size, 1], fill_value=self.presence_score, dtype="float32")
|
||||||
self.share_inputs['presence_score'] = paddle.full(shape=[self.args.max_batch_size, 1],
|
|
||||||
fill_value=self.presence_score,
|
|
||||||
dtype="float32")
|
|
||||||
self.share_inputs['seq_lens_this_time'] = paddle.full(
|
self.share_inputs['seq_lens_this_time'] = paddle.full(
|
||||||
shape=[self.args.max_batch_size, 1], fill_value=0, dtype="int32")
|
shape=[self.args.max_batch_size, 1], fill_value=0, dtype="int32")
|
||||||
self.share_inputs['seq_lens_encoder'] = paddle.full(shape=[self.args.max_batch_size, 1],
|
self.share_inputs['seq_lens_encoder'] = paddle.full(
|
||||||
fill_value=0,
|
shape=[self.args.max_batch_size, 1], fill_value=0, dtype="int32")
|
||||||
dtype="int32")
|
|
||||||
self.share_inputs['step_seq_lens_encoder'] = paddle.full(
|
self.share_inputs['step_seq_lens_encoder'] = paddle.full(
|
||||||
shape=[self.args.max_batch_size, 1], fill_value=0, dtype="int32")
|
shape=[self.args.max_batch_size, 1], fill_value=0, dtype="int32")
|
||||||
self.share_inputs['seq_lens_decoder'] = paddle.full(shape=[self.args.max_batch_size, 1],
|
self.share_inputs['seq_lens_decoder'] = paddle.full(
|
||||||
fill_value=0,
|
shape=[self.args.max_batch_size, 1], fill_value=0, dtype="int32")
|
||||||
dtype="int32")
|
self.share_inputs['step_idx'] = paddle.full(
|
||||||
self.share_inputs['step_idx'] = paddle.full(shape=[self.args.max_batch_size, 1],
|
shape=[self.args.max_batch_size, 1], fill_value=0, dtype="int64")
|
||||||
fill_value=0,
|
self.share_inputs['min_length'] = paddle.full(
|
||||||
dtype="int64")
|
shape=[self.args.max_batch_size, 1], fill_value=self.min_length, dtype="int64")
|
||||||
self.share_inputs['min_length'] = paddle.full(shape=[self.args.max_batch_size, 1],
|
self.share_inputs['max_length'] = paddle.full(
|
||||||
fill_value=self.min_length,
|
shape=[self.args.max_batch_size, 1], fill_value=self.max_length, dtype="int64")
|
||||||
dtype="int64")
|
self.share_inputs['not_need_stop'] = paddle.full(
|
||||||
self.share_inputs['max_length'] = paddle.full(shape=[self.args.max_batch_size, 1],
|
shape=[1], fill_value=False, dtype="bool")
|
||||||
fill_value=self.max_length,
|
self.share_inputs['stop_flags'] = paddle.full(
|
||||||
dtype="int64")
|
shape=[self.args.max_batch_size, 1], fill_value=True, dtype="bool")
|
||||||
self.share_inputs['not_need_stop'] = paddle.full(shape=[1],
|
self.share_inputs['stop_nums'] = paddle.full(
|
||||||
fill_value=False,
|
shape=[1], fill_value=self.args.max_batch_size, dtype="int64")
|
||||||
dtype="bool")
|
self.share_inputs['bad_tokens'] = paddle.full(
|
||||||
self.share_inputs['stop_flags'] = paddle.full(shape=[self.args.max_batch_size, 1],
|
shape=[1], fill_value=-1, dtype="int64")
|
||||||
fill_value=True,
|
self.share_inputs['next_tokens'] = paddle.full(
|
||||||
dtype="bool")
|
shape=[self.args.max_batch_size, 1], fill_value=-1, dtype="int64")
|
||||||
self.share_inputs['stop_nums'] = paddle.full(shape=[1],
|
self.share_inputs['is_block_step'] = paddle.full(
|
||||||
fill_value=self.args.max_batch_size,
|
shape=[self.args.max_batch_size], fill_value=False, dtype="bool")
|
||||||
dtype="int64")
|
self.share_inputs['encoder_block_lens'] = paddle.full(
|
||||||
self.share_inputs['bad_tokens'] = paddle.full(shape=[1],
|
shape=[self.args.max_batch_size], fill_value=0, dtype="int32")
|
||||||
fill_value=-1,
|
self.share_inputs['step_block_list'] = paddle.full(
|
||||||
dtype="int64")
|
shape=[self.args.max_batch_size], fill_value=-1, dtype="int32")
|
||||||
self.share_inputs['next_tokens'] = paddle.full(shape=[self.args.max_batch_size, 1],
|
|
||||||
fill_value=-1,
|
|
||||||
dtype="int64")
|
|
||||||
self.share_inputs['is_block_step'] = paddle.full(shape=[self.args.max_batch_size],
|
|
||||||
fill_value=False,
|
|
||||||
dtype="bool")
|
|
||||||
self.share_inputs['encoder_block_lens'] = paddle.full(shape=[self.args.max_batch_size],
|
|
||||||
fill_value=0,
|
|
||||||
dtype="int32")
|
|
||||||
self.share_inputs['step_block_list'] = paddle.full(shape=[self.args.max_batch_size],
|
|
||||||
fill_value=-1,
|
|
||||||
dtype="int32")
|
|
||||||
self.share_inputs['step_lens'] = paddle.full(shape=[1], fill_value=0, dtype="int32")
|
self.share_inputs['step_lens'] = paddle.full(shape=[1], fill_value=0, dtype="int32")
|
||||||
self.share_inputs['recover_block_list'] = paddle.full(shape=[self.args.max_batch_size],
|
self.share_inputs['recover_block_list'] = paddle.full(
|
||||||
fill_value=-1,
|
shape=[self.args.max_batch_size], fill_value=-1, dtype="int32")
|
||||||
dtype="int32")
|
self.share_inputs['recover_lens'] = paddle.full(
|
||||||
self.share_inputs['recover_lens'] = paddle.full(shape=[1],
|
shape=[1], fill_value=0, dtype="int32")
|
||||||
fill_value=0,
|
self.share_inputs['need_block_list'] = paddle.full(
|
||||||
dtype="int32")
|
shape=[self.args.max_batch_size], fill_value=-1, dtype="int32")
|
||||||
self.share_inputs['need_block_list'] = paddle.full(shape=[self.args.max_batch_size],
|
self.share_inputs['need_block_len'] = paddle.full(
|
||||||
fill_value=-1,
|
shape=[1], fill_value=0, dtype="int32")
|
||||||
dtype="int32")
|
self.share_inputs['used_list_len'] = paddle.full(
|
||||||
self.share_inputs['need_block_len'] = paddle.full(shape=[1],
|
shape=[self.args.max_batch_size], fill_value=0, dtype="int32")
|
||||||
fill_value=0,
|
self.share_inputs['infer_seed'] = paddle.full(
|
||||||
dtype="int32")
|
shape=[self.args.max_batch_size, 1], fill_value=0, dtype="int64")
|
||||||
self.share_inputs['used_list_len'] = paddle.full(shape=[self.args.max_batch_size],
|
|
||||||
fill_value=0,
|
|
||||||
dtype="int32")
|
|
||||||
self.share_inputs['infer_seed'] = paddle.full(shape=[self.args.max_batch_size, 1],
|
|
||||||
fill_value=0,
|
|
||||||
dtype="int64")
|
|
||||||
free_list = list(range(int(self.args.max_block_num * self.args.block_ratio)))
|
free_list = list(range(int(self.args.max_block_num * self.args.block_ratio)))
|
||||||
self.free_list_len = len(free_list)
|
self.free_list_len = len(free_list)
|
||||||
self.share_inputs['free_list'] = paddle.to_tensor(free_list, dtype="int32")
|
self.share_inputs['free_list'] = paddle.to_tensor(free_list, dtype="int32")
|
||||||
self.share_inputs['free_list_len'] = paddle.full(shape=[1],
|
self.share_inputs['free_list_len'] = paddle.full(
|
||||||
fill_value=self.free_list_len,
|
shape=[1], fill_value=self.free_list_len, dtype="int32")
|
||||||
dtype="int32")
|
|
||||||
|
|
||||||
def dy_input_preprocess(self, tasks):
|
def dy_input_preprocess(self, tasks):
|
||||||
"""
|
"""
|
||||||
动态插入部分额外处理
|
dynamic insertion
|
||||||
"""
|
"""
|
||||||
for i in range(len(tasks)):
|
for i in range(len(tasks)):
|
||||||
task = tasks[i]
|
task = tasks[i]
|
||||||
@@ -309,7 +290,7 @@ class ModelRunner:
|
|||||||
|
|
||||||
def step_cuda(self, seq_lens_this_time):
|
def step_cuda(self, seq_lens_this_time):
|
||||||
"""
|
"""
|
||||||
block调度
|
step cuda
|
||||||
"""
|
"""
|
||||||
step_paddle(self.share_inputs['stop_flags'], seq_lens_this_time,
|
step_paddle(self.share_inputs['stop_flags'], seq_lens_this_time,
|
||||||
self.share_inputs['step_seq_lens_encoder'],
|
self.share_inputs['step_seq_lens_encoder'],
|
||||||
@@ -327,7 +308,11 @@ class ModelRunner:
|
|||||||
|
|
||||||
def initialize_engine_ready_check_flag(self):
|
def initialize_engine_ready_check_flag(self):
|
||||||
"""
|
"""
|
||||||
初始化共享内存中引擎准备就绪标志变量
|
initialize engine ready flag in shared memory
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
shm_engine_ready_check_flag: engine ready flag
|
||||||
|
engine_ready_check_flag_array: engine ready flag array
|
||||||
"""
|
"""
|
||||||
engine_ready_check_flag = np.zeros([1], dtype=np.int32)
|
engine_ready_check_flag = np.zeros([1], dtype=np.int32)
|
||||||
shm_engine_ready_check_flag = shared_memory.SharedMemory(
|
shm_engine_ready_check_flag = shared_memory.SharedMemory(
|
||||||
@@ -339,7 +324,10 @@ class ModelRunner:
|
|||||||
|
|
||||||
def initialize_engine_live_flag(self):
|
def initialize_engine_live_flag(self):
|
||||||
"""
|
"""
|
||||||
创建用来表明当前infer引擎进程存在的共享内存变量
|
initialize infer live flag in shared memory
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
infer_live_flag_shm: infer live flag
|
||||||
"""
|
"""
|
||||||
infer_live_flag_shm = shared_memory.SharedMemory(create=True,
|
infer_live_flag_shm = shared_memory.SharedMemory(create=True,
|
||||||
size=1,
|
size=1,
|
||||||
@@ -348,7 +336,10 @@ class ModelRunner:
|
|||||||
|
|
||||||
def initialize_engine_healthy_recorded_time_flag(self):
|
def initialize_engine_healthy_recorded_time_flag(self):
|
||||||
"""
|
"""
|
||||||
初始化共享内存中记录引擎健康的时间戳变量
|
initialize engine healthy recorded time flag in shared memory
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
shm_engine_healthy_recorded_time: engine healthy recorded time flag
|
||||||
"""
|
"""
|
||||||
engine_healthy_recorded_time = np.zeros([1], dtype=float)
|
engine_healthy_recorded_time = np.zeros([1], dtype=float)
|
||||||
shm_engine_healthy_recorded_time = shared_memory.SharedMemory(
|
shm_engine_healthy_recorded_time = shared_memory.SharedMemory(
|
||||||
@@ -359,7 +350,9 @@ class ModelRunner:
|
|||||||
return shm_engine_healthy_recorded_time, engine_healthy_recorded_time_array
|
return shm_engine_healthy_recorded_time, engine_healthy_recorded_time_array
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
# 共享内存设置 #
|
"""
|
||||||
|
run infer
|
||||||
|
"""
|
||||||
flag_array = np.zeros([1], dtype=np.int32)
|
flag_array = np.zeros([1], dtype=np.int32)
|
||||||
shm_flag_broadcast = shared_memory.SharedMemory(
|
shm_flag_broadcast = shared_memory.SharedMemory(
|
||||||
name=self.config.get_unique_name("shm_pd_infer_flag_broadcast"))
|
name=self.config.get_unique_name("shm_pd_infer_flag_broadcast"))
|
||||||
@@ -372,7 +365,7 @@ class ModelRunner:
|
|||||||
flag_ready_array = np.ndarray(flag_array.shape,
|
flag_ready_array = np.ndarray(flag_array.shape,
|
||||||
dtype=flag_array.dtype,
|
dtype=flag_array.dtype,
|
||||||
buffer=shm_flag_ready.buf)
|
buffer=shm_flag_ready.buf)
|
||||||
flag_ready_array[self.rank] = 1 # 已初始化完毕
|
flag_ready_array[self.rank] = 1
|
||||||
|
|
||||||
flag_array = np.zeros([1], dtype=np.int32)
|
flag_array = np.zeros([1], dtype=np.int32)
|
||||||
shm_flag_has_block_step = shared_memory.SharedMemory(name=self.config.get_unique_name("shm_flag_has_block_step"))
|
shm_flag_has_block_step = shared_memory.SharedMemory(name=self.config.get_unique_name("shm_flag_has_block_step"))
|
||||||
@@ -386,23 +379,19 @@ class ModelRunner:
|
|||||||
engine_ready_check_flag_array[0] = 1
|
engine_ready_check_flag_array[0] = 1
|
||||||
shm_engine_healthy_recorded_time_array, engine_healthy_recorded_time_array = self.initialize_engine_healthy_recorded_time_flag()
|
shm_engine_healthy_recorded_time_array, engine_healthy_recorded_time_array = self.initialize_engine_healthy_recorded_time_flag()
|
||||||
engine_healthy_recorded_time_array[0] = time.time()
|
engine_healthy_recorded_time_array[0] = time.time()
|
||||||
# 创建代表infer存活的共享变量
|
|
||||||
infer_live_flag_shm = self.initialize_engine_live_flag()
|
infer_live_flag_shm = self.initialize_engine_live_flag()
|
||||||
|
|
||||||
infer_seed_increment = paddle.full(shape=[self.args.max_batch_size, 1],
|
infer_seed_increment = paddle.full(shape=[self.args.max_batch_size, 1],
|
||||||
fill_value=4,
|
fill_value=4,
|
||||||
dtype="int64")
|
dtype="int64")
|
||||||
|
|
||||||
thread_executor = ThreadPoolExecutor(max_workers=1)
|
thread_executor = ThreadPoolExecutor(max_workers=1)
|
||||||
seq_lens_this_time = None
|
seq_lens_this_time = None
|
||||||
real_bsz = None
|
real_bsz = None
|
||||||
|
|
||||||
while 1:
|
while True:
|
||||||
if use_custom_health_checker:
|
if use_custom_health_checker:
|
||||||
engine_healthy_recorded_time_array[0] = time.time()
|
engine_healthy_recorded_time_array[0] = time.time()
|
||||||
|
|
||||||
if self.rank == 0:
|
if self.rank == 0:
|
||||||
# 队列不为空, 可取出数据
|
|
||||||
if not self.infer_queue.empty():
|
if not self.infer_queue.empty():
|
||||||
flag_broadcast_array[0] = 1
|
flag_broadcast_array[0] = 1
|
||||||
|
|
||||||
@@ -427,7 +416,6 @@ class ModelRunner:
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.dy_input_preprocess(req_dicts)
|
self.dy_input_preprocess(req_dicts)
|
||||||
# 特殊处理seq_lens
|
|
||||||
seq_lens_this_time = copy.deepcopy(
|
seq_lens_this_time = copy.deepcopy(
|
||||||
self.share_inputs['seq_lens_this_time'][:real_bsz])
|
self.share_inputs['seq_lens_this_time'][:real_bsz])
|
||||||
self.infer_engine.seq_lens_handle.share_external_data(
|
self.infer_engine.seq_lens_handle.share_external_data(
|
||||||
@@ -440,12 +428,10 @@ class ModelRunner:
|
|||||||
|
|
||||||
time.sleep(0.001)
|
time.sleep(0.001)
|
||||||
continue
|
continue
|
||||||
self.infer_engine.predictor.run()
|
|
||||||
|
|
||||||
# 自增随机种子,让每次计算的种子不一样
|
self.infer_engine.predictor.run()
|
||||||
self.share_inputs['infer_seed'].add_(infer_seed_increment)
|
self.share_inputs['infer_seed'].add_(infer_seed_increment)
|
||||||
self.share_inputs['infer_seed'][:] %= self.MAX_INFER_SEED
|
self.share_inputs['infer_seed'][:] %= self.MAX_INFER_SEED
|
||||||
|
|
||||||
if self.free_list_len > 0:
|
if self.free_list_len > 0:
|
||||||
self.step_cuda(seq_lens_this_time)
|
self.step_cuda(seq_lens_this_time)
|
||||||
|
|
||||||
@@ -459,9 +445,6 @@ class InferenceEngine(object):
|
|||||||
mp_degree (int): model parallel size
|
mp_degree (int): model parallel size
|
||||||
"""
|
"""
|
||||||
def __init__(self, model_dir, share_inputs, cache_kvs, config, mp_degree=1):
|
def __init__(self, model_dir, share_inputs, cache_kvs, config, mp_degree=1):
|
||||||
"""
|
|
||||||
初始化模型目录,并设置多进程环境。
|
|
||||||
"""
|
|
||||||
self.config = config
|
self.config = config
|
||||||
self.model_dir = model_dir
|
self.model_dir = model_dir
|
||||||
self.mp_degree = mp_degree
|
self.mp_degree = mp_degree
|
||||||
@@ -480,13 +463,14 @@ class InferenceEngine(object):
|
|||||||
self.share_data()
|
self.share_data()
|
||||||
|
|
||||||
def _init_predictor(self):
|
def _init_predictor(self):
|
||||||
"""predictor init"""
|
"""
|
||||||
|
predictor init
|
||||||
|
"""
|
||||||
device_id = self.rank % 8
|
device_id = self.rank % 8
|
||||||
self.model_file = os.path.join(self.model_dir, f"model.pdmodel")
|
self.model_file = os.path.join(self.model_dir, f"model.pdmodel")
|
||||||
self.param_file = os.path.join(self.model_dir, f"model.pdiparams")
|
self.param_file = os.path.join(self.model_dir, f"model.pdiparams")
|
||||||
config = paddle.inference.Config(self.model_file, self.param_file)
|
config = paddle.inference.Config(self.model_file, self.param_file)
|
||||||
|
|
||||||
# config.enable_memory_optim()
|
|
||||||
config.switch_ir_optim(False)
|
config.switch_ir_optim(False)
|
||||||
config.enable_use_gpu(100, device_id)
|
config.enable_use_gpu(100, device_id)
|
||||||
|
|
||||||
@@ -507,7 +491,7 @@ class InferenceEngine(object):
|
|||||||
)
|
)
|
||||||
dist_config.set_comm_init_config(
|
dist_config.set_comm_init_config(
|
||||||
os.path.join(Dir_Path + "/config", "rank_mapping_mp{}.csv".format(self.nranks)))
|
os.path.join(Dir_Path + "/config", "rank_mapping_mp{}.csv".format(self.nranks)))
|
||||||
# dist_config.set_comm_init_config(os.path.join(Dir_Path + "/config", "rank_mapping.csv"))
|
|
||||||
config.set_dist_config(dist_config)
|
config.set_dist_config(dist_config)
|
||||||
self.predictor = paddle.inference.create_predictor(config)
|
self.predictor = paddle.inference.create_predictor(config)
|
||||||
self.input_names = self.predictor.get_input_names()
|
self.input_names = self.predictor.get_input_names()
|
||||||
@@ -515,7 +499,7 @@ class InferenceEngine(object):
|
|||||||
|
|
||||||
def share_data(self):
|
def share_data(self):
|
||||||
"""
|
"""
|
||||||
分享不拷贝数据
|
share data
|
||||||
"""
|
"""
|
||||||
for name in self.input_names:
|
for name in self.input_names:
|
||||||
if "caches" in name:
|
if "caches" in name:
|
||||||
@@ -542,7 +526,7 @@ class InferenceEngine(object):
|
|||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
"""
|
"""
|
||||||
从命令行解析参数
|
parse args from command line
|
||||||
"""
|
"""
|
||||||
parser = argparse.ArgumentParser("FastDeploy LLM Inference")
|
parser = argparse.ArgumentParser("FastDeploy LLM Inference")
|
||||||
parser.add_argument('-m',
|
parser.add_argument('-m',
|
||||||
@@ -596,7 +580,7 @@ def parse_args():
|
|||||||
|
|
||||||
def main():
|
def main():
|
||||||
"""
|
"""
|
||||||
启动推理引擎并进行预测
|
start model runner
|
||||||
"""
|
"""
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
model_runner = ModelRunner(args)
|
model_runner = ModelRunner(args)
|
||||||
|
@@ -24,46 +24,71 @@ from server.utils import model_server_logger
|
|||||||
|
|
||||||
class ResourceManager(object):
|
class ResourceManager(object):
|
||||||
"""
|
"""
|
||||||
用于记录和分配引擎的资源
|
record and allocate resources for the engine
|
||||||
"""
|
"""
|
||||||
def __init__(self, cfg):
|
def __init__(self, cfg):
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
self.stop_flags = [True] * cfg.max_batch_size
|
self.stop_flags = [True] * cfg.max_batch_size
|
||||||
self.free_list = list(range(cfg.max_block_num - 1, -1, -1))
|
self.free_list = list(range(cfg.max_block_num - 1, -1, -1))
|
||||||
self.tasks_list = [None] * self.cfg.max_batch_size
|
self.tasks_list = [None] * self.cfg.max_batch_size
|
||||||
# 引擎当前的batch情况
|
# current batch status of the engine
|
||||||
self.real_bsz = 0
|
self.real_bsz = 0
|
||||||
model_server_logger.info(f"{self.info()}")
|
model_server_logger.info(f"{self.info()}")
|
||||||
|
|
||||||
def get_required_block_number(self, input_token_num):
|
def get_required_block_number(self, input_token_num):
|
||||||
"""
|
"""
|
||||||
计算需要多少Block资源
|
Calculate Block resources are needed
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_token_num (int): input token number
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: block number
|
||||||
"""
|
"""
|
||||||
block_num = (input_token_num + self.cfg.block_size - 1 + self.cfg.dec_token_num) // self.cfg.block_size
|
block_num = (input_token_num + self.cfg.block_size - 1 + self.cfg.dec_token_num) // self.cfg.block_size
|
||||||
return block_num
|
return block_num
|
||||||
|
|
||||||
def get_encoder_block_number(self, input_token_num):
|
def get_encoder_block_number(self, input_token_num):
|
||||||
"""
|
"""
|
||||||
获取编码器所需的block数目
|
get the number of blocks for the encoder
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_token_num (int): input token number
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: encoder block number
|
||||||
"""
|
"""
|
||||||
enc_block_num = (input_token_num + self.cfg.block_size - 1) // self.cfg.block_size
|
enc_block_num = (input_token_num + self.cfg.block_size - 1) // self.cfg.block_size
|
||||||
return enc_block_num
|
return enc_block_num
|
||||||
|
|
||||||
def get_decoder_block_number(self):
|
def get_decoder_block_number(self):
|
||||||
"""
|
"""
|
||||||
获取解码器所需的block数目
|
get the number of blocks for the decoder
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: decoder block number
|
||||||
"""
|
"""
|
||||||
return (self.cfg.dec_token_num + self.cfg.block_size - 1) // self.cfg.block_size
|
return (self.cfg.dec_token_num + self.cfg.block_size - 1) // self.cfg.block_size
|
||||||
|
|
||||||
def total_block_number(self):
|
def total_block_number(self):
|
||||||
"""
|
"""
|
||||||
返回服务启动时预分配的block数量
|
the number of pre allocated blocks at service startup
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: total block number
|
||||||
"""
|
"""
|
||||||
return self.cfg.max_block_num
|
return self.cfg.max_block_num
|
||||||
|
|
||||||
def _get_block_tables(self, input_token_num, required_type="all"):
|
def _get_block_tables(self, input_token_num, required_type="all"):
|
||||||
"""
|
"""
|
||||||
分配显存资源
|
allocate memory resources
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_token_num (int): input token number
|
||||||
|
required_type (str): required type
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: block list
|
||||||
"""
|
"""
|
||||||
if required_type == "all":
|
if required_type == "all":
|
||||||
block_num = self.get_required_block_number(input_token_num)
|
block_num = self.get_required_block_number(input_token_num)
|
||||||
@@ -86,29 +111,43 @@ class ResourceManager(object):
|
|||||||
|
|
||||||
def _recycle_block_tables(self, block_tables):
|
def _recycle_block_tables(self, block_tables):
|
||||||
"""
|
"""
|
||||||
回收显存资源blocks
|
Recycling memory resource blocks
|
||||||
|
|
||||||
|
Args:
|
||||||
|
block_tables (list): block list
|
||||||
"""
|
"""
|
||||||
ori_number = len(self.free_list)
|
ori_number = len(self.free_list)
|
||||||
self.free_list.extend(block_tables)
|
self.free_list.extend(block_tables)
|
||||||
# self.free_list = list(set(self.free_list + block_tables))
|
|
||||||
cur_number = len(self.free_list)
|
cur_number = len(self.free_list)
|
||||||
model_server_logger.info(f"recycle {cur_number - ori_number} blocks.")
|
model_server_logger.info(f"recycle {cur_number - ori_number} blocks.")
|
||||||
|
|
||||||
def available_batch(self):
|
def available_batch(self):
|
||||||
"""
|
"""
|
||||||
引擎当前可用最大Batch
|
available batch size for engine
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: available batch size
|
||||||
"""
|
"""
|
||||||
return np.sum(self.stop_flags)
|
return np.sum(self.stop_flags)
|
||||||
|
|
||||||
def availabel_block_num(self):
|
def availabel_block_num(self):
|
||||||
"""
|
"""
|
||||||
引擎当前可用的block数量
|
available block size for engine
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: available block size
|
||||||
"""
|
"""
|
||||||
return len(self.free_list)
|
return len(self.free_list)
|
||||||
|
|
||||||
def is_resource_sufficient(self, input_token_num):
|
def is_resource_sufficient(self, input_token_num):
|
||||||
"""
|
"""
|
||||||
判断当前可用资源是否满足新的需求
|
check current available resources meet the new requirements
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_token_num (int): input token number
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: whether current available resources meet the new requirements
|
||||||
"""
|
"""
|
||||||
if self.available_batch() < 1:
|
if self.available_batch() < 1:
|
||||||
return False
|
return False
|
||||||
@@ -119,11 +158,17 @@ class ResourceManager(object):
|
|||||||
|
|
||||||
def allocate_resources_for_new_tasks(self, tasks):
|
def allocate_resources_for_new_tasks(self, tasks):
|
||||||
"""
|
"""
|
||||||
为新任务分配资源
|
allocate resources for new tasks
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tasks (list): task list
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: processed task list
|
||||||
"""
|
"""
|
||||||
|
|
||||||
allocated_position = 0 # 新任务插入的位置
|
allocated_position = 0
|
||||||
processing_task_index = 0 # 当前正在处理的任务index
|
processing_task_index = 0
|
||||||
processed_tasks = list()
|
processed_tasks = list()
|
||||||
while allocated_position < self.cfg.max_batch_size:
|
while allocated_position < self.cfg.max_batch_size:
|
||||||
if processing_task_index >= len(tasks):
|
if processing_task_index >= len(tasks):
|
||||||
@@ -172,7 +217,7 @@ class ResourceManager(object):
|
|||||||
allocated_position += 1
|
allocated_position += 1
|
||||||
processing_task_index += 1
|
processing_task_index += 1
|
||||||
|
|
||||||
# 统计引擎正在推理时的batch size
|
# batch size when the statistical engine is inferring
|
||||||
for i in range(self.cfg.max_batch_size - 1, -1, -1):
|
for i in range(self.cfg.max_batch_size - 1, -1, -1):
|
||||||
if not self.stop_flags[i]:
|
if not self.stop_flags[i]:
|
||||||
self.real_bsz = i + 1
|
self.real_bsz = i + 1
|
||||||
@@ -184,6 +229,12 @@ class ResourceManager(object):
|
|||||||
return processed_tasks
|
return processed_tasks
|
||||||
|
|
||||||
def info(self):
|
def info(self):
|
||||||
|
"""
|
||||||
|
get resource manager info
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: resource manager info
|
||||||
|
"""
|
||||||
info = f"ResourceManager info, " \
|
info = f"ResourceManager info, " \
|
||||||
f"total_block_number: {self.total_block_number()}, total_batch_number: {len(self.stop_flags)}, " \
|
f"total_block_number: {self.total_block_number()}, total_batch_number: {len(self.stop_flags)}, " \
|
||||||
f"availabel_block_num: {self.availabel_block_num()}, available_batch: {self.available_batch()}"
|
f"availabel_block_num: {self.availabel_block_num()}, available_batch: {self.available_batch()}"
|
||||||
|
@@ -15,14 +15,9 @@
|
|||||||
import os
|
import os
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
|
from multiprocessing.managers import (AcquirerProxy, BaseManager, ListProxy,
|
||||||
|
Value, ValueProxy)
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
from multiprocessing.managers import (
|
|
||||||
AcquirerProxy,
|
|
||||||
BaseManager,
|
|
||||||
ListProxy,
|
|
||||||
Value,
|
|
||||||
ValueProxy,
|
|
||||||
)
|
|
||||||
|
|
||||||
from server.utils import get_logger
|
from server.utils import get_logger
|
||||||
|
|
||||||
@@ -31,7 +26,7 @@ logger = get_logger("infer_server", "task_queue_manager.log")
|
|||||||
|
|
||||||
class QueueManager(BaseManager):
|
class QueueManager(BaseManager):
|
||||||
"""
|
"""
|
||||||
基础类
|
base class for queue manager
|
||||||
"""
|
"""
|
||||||
|
|
||||||
pass
|
pass
|
||||||
@@ -39,12 +34,13 @@ class QueueManager(BaseManager):
|
|||||||
|
|
||||||
class TaskQueueManager(object):
|
class TaskQueueManager(object):
|
||||||
"""
|
"""
|
||||||
管理类
|
task queue manager
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, rank=0, mp_num=8, port=56666):
|
def __init__(self, rank=0, mp_num=8, port=56666):
|
||||||
"""
|
"""
|
||||||
初始化函数,用于创建对象时进行初始化操作。
|
Initialization function, used to perform initialization
|
||||||
|
operations when creating objects
|
||||||
"""
|
"""
|
||||||
self.max_get_num = int(os.getenv("ENGINE_MAX_NEED_NUM", 0))
|
self.max_get_num = int(os.getenv("ENGINE_MAX_NEED_NUM", 0))
|
||||||
QueueManager.register('get_list')
|
QueueManager.register('get_list')
|
||||||
@@ -72,7 +68,10 @@ class TaskQueueManager(object):
|
|||||||
|
|
||||||
def empty(self):
|
def empty(self):
|
||||||
"""
|
"""
|
||||||
暴露至推理端,用于判断队列是否为空
|
check the queue is empty for infer
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the queue is empty, otherwise False
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
return len(self.list) == 0
|
return len(self.list) == 0
|
||||||
@@ -82,7 +81,10 @@ class TaskQueueManager(object):
|
|||||||
|
|
||||||
def put(self, item):
|
def put(self, item):
|
||||||
"""
|
"""
|
||||||
向队列中添加数据
|
put item to queue
|
||||||
|
|
||||||
|
Args:
|
||||||
|
item (any): the item to put into queue
|
||||||
"""
|
"""
|
||||||
self.lock.acquire()
|
self.lock.acquire()
|
||||||
if 0 < self.value.get() < self.total_num:
|
if 0 < self.value.get() < self.total_num:
|
||||||
@@ -100,13 +102,16 @@ class TaskQueueManager(object):
|
|||||||
|
|
||||||
def get(self):
|
def get(self):
|
||||||
"""
|
"""
|
||||||
从队列中获取数据
|
get item from queue
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: the item from queue
|
||||||
|
bool: True if the queue is empty, otherwise False
|
||||||
"""
|
"""
|
||||||
input_list = []
|
input_list = []
|
||||||
read_finish = False
|
read_finish = False
|
||||||
self.lock.acquire()
|
self.lock.acquire()
|
||||||
if self.value.get() & self.position == 0 and len(self.list) > 0:
|
if self.value.get() & self.position == 0 and len(self.list) > 0:
|
||||||
# 控制进入引擎的输入数量. 默认服务中所有输入都拷贝进引擎一起处理
|
|
||||||
if self.max_get_num > 0:
|
if self.max_get_num > 0:
|
||||||
input_list.extend(self.list[: self.max_get_num])
|
input_list.extend(self.list[: self.max_get_num])
|
||||||
else:
|
else:
|
||||||
@@ -128,10 +133,11 @@ class TaskQueueManager(object):
|
|||||||
|
|
||||||
def launch_queue_service(port, num_workers):
|
def launch_queue_service(port, num_workers):
|
||||||
"""
|
"""
|
||||||
启动进程间通信队列服务
|
Start the process communication queue service
|
||||||
|
|
||||||
port: 监听端口号
|
Args:
|
||||||
num_workers: infer进程的数量
|
port (int): the port to listen
|
||||||
|
num_workers (int): the number of infer process
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
logger.info(f"start launch queue service, port:{port}")
|
logger.info(f"start launch queue service, port:{port}")
|
||||||
|
@@ -16,26 +16,24 @@ import os
|
|||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
from paddlenlp_ops import get_output
|
from paddlenlp_ops import get_output
|
||||||
from server.utils import datetime_diff, model_server_logger, monitor_logger
|
from server.utils import datetime_diff, model_server_logger, monitor_logger
|
||||||
|
|
||||||
|
|
||||||
class TokenProcessor(object):
|
class TokenProcessor(object):
|
||||||
"""
|
"""
|
||||||
持续从Paddle底层引擎队列中获取生成Token/Score,并进行处理
|
get Token/Score from Paddle inference engine
|
||||||
"""
|
"""
|
||||||
def __init__(self, cfg):
|
def __init__(self, cfg):
|
||||||
import paddle
|
import paddle
|
||||||
paddle.device.set_device("cpu")
|
paddle.device.set_device("cpu")
|
||||||
# 服务配置
|
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
# 引擎状态
|
|
||||||
self.resource_manager = None
|
self.resource_manager = None
|
||||||
# 记录每个请求的当前所有生成Token
|
# record all tokens for each request
|
||||||
self.all_tokens = [[] for _ in range(self.cfg.max_batch_size)]
|
self.all_tokens = [[] for _ in range(self.cfg.max_batch_size)]
|
||||||
|
|
||||||
self.tokens_counter = Counter()
|
self.tokens_counter = Counter()
|
||||||
@@ -51,14 +49,17 @@ class TokenProcessor(object):
|
|||||||
|
|
||||||
def set_resource_manager(self, resource_manager):
|
def set_resource_manager(self, resource_manager):
|
||||||
"""
|
"""
|
||||||
设置ResourceManager
|
set ResourceManager
|
||||||
|
|
||||||
|
Args:
|
||||||
|
resource_manager (ResourceManager)
|
||||||
"""
|
"""
|
||||||
assert self.resource_manager is None, "The resource manager is not None, cannot set again."
|
assert self.resource_manager is None, "The resource manager is not None, cannot set again."
|
||||||
self.resource_manager = resource_manager
|
self.resource_manager = resource_manager
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
"""
|
"""
|
||||||
启动子线程,持续处理生成Token
|
start thread to get tokens
|
||||||
"""
|
"""
|
||||||
assert self.resource_manager is not None, "The resource manager is None, cannot run."
|
assert self.resource_manager is not None, "The resource manager is None, cannot run."
|
||||||
if self.worker is not None:
|
if self.worker is not None:
|
||||||
@@ -70,7 +71,7 @@ class TokenProcessor(object):
|
|||||||
|
|
||||||
def process_sampling_results(self):
|
def process_sampling_results(self):
|
||||||
"""
|
"""
|
||||||
循环获取输出,并处理数据
|
read tokens from paddle inference engine and process
|
||||||
"""
|
"""
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
@@ -86,7 +87,11 @@ class TokenProcessor(object):
|
|||||||
|
|
||||||
def postprocess(self, batch_result, exist_finished_task=False):
|
def postprocess(self, batch_result, exist_finished_task=False):
|
||||||
"""
|
"""
|
||||||
生成单步结果后处理函数
|
single post-processing function
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch_result (list): batch results
|
||||||
|
exist_finished_task (bool): whether there is a finished task
|
||||||
"""
|
"""
|
||||||
result_dir = "./generate_token_results"
|
result_dir = "./generate_token_results"
|
||||||
if not os.path.exists(result_dir):
|
if not os.path.exists(result_dir):
|
||||||
@@ -98,7 +103,16 @@ class TokenProcessor(object):
|
|||||||
|
|
||||||
 def _get_single_result(self, i, task_id, token_id, task):
 """
-处理单步生成结果
+processing single results
 
+Args:
+i (int): batch index
+task_id (str): task id
+token_id (int): token id
+task (dict): task information
+
+Returns:
+dict: result
 """
 inference_time_cost = time.time() - task["inference_start_time"]
 task["inference_time_cost"] = inference_time_cost
@@ -114,7 +128,7 @@ class TokenProcessor(object):
 "return_all_tokens": task.get("return_all_tokens", False),
 }
 
-# 收集benchmark信息
+# get benchmark msg
 if task.get("benchmark"):
 keys = ["preprocess_start_time", "preprocess_end_time", "schedule_start_time",
 "inference_start_time", "inference_current_step_time"]
@@ -122,14 +136,13 @@ class TokenProcessor(object):
 if key in task:
 result[key] = str(task[key])
 
-# 生成结束符时,额外填充部分信息
+# fill some extra information
 if token_id in task["eos_token_ids"]:
 result["is_end"] = 1
 result["token_ids"] = []
 result["tokens_all_num"] = len(self.all_tokens[i]) + 1
 result["tokens_all_ids"] = self.all_tokens[i]
 
-# 生成请求的完整日志,用于平台监控
 info_dict = {}
 info_dict["req_id"] = task["req_id"]
 info_dict["input_token_num"] = len(task["input_ids"])
@@ -149,7 +162,7 @@ class TokenProcessor(object):
 
 def _recycle_resources(self, task_id, index, task):
 """
-对于已完成的任务,回收资源
+recycle resources
 """
 self.resource_manager.stop_flags[index] = True
 self.resource_manager.tasks_list[index] = None
@@ -158,29 +171,15 @@ class TokenProcessor(object):
 del self.tokens_counter[task_id]
 self.all_tokens[index] = list()
 
-def _recycle_beam_resources(self, task_id_list, index_list, block_tables):
-assert len(task_id_list) == len(index_list), \
-f"{len(task_id_list)} task_id don't equal to {len(index_list)} index"
-self.resource_manager._recycle_block_tables(block_tables)
-for i in range(len(task_id_list)):
-task_id = task_id_list[i]
-index = index_list[i]
-self.resource_manager.tasks_list[index] = None
-self.resource_manager.stop_flags[index] = True
-if task_id in self.tokens_counter:
-del self.tokens_counter[task_id]
-self.all_tokens[index] = list()
 
 def _process_batch_output(self):
 """
-处理一个batch的输出结果
+batch post-processing function
 """
 tokens = self.output_tokens.numpy()
 batch = self.output_tokens[1, 0]
 tokens = tokens[2:batch + 2]
 
 batch_result = list()
-# 用于判断当前此批结果中是否存在已完成的任务
 exist_finished_task = False
 for i in range(batch):
 if self.resource_manager.stop_flags[i]:
@@ -212,7 +211,7 @@ class TokenProcessor(object):
 
 class WarmUpTokenProcessor(TokenProcessor):
 """
-创建warm up服务的Processor
+Warmup Processor
 """
 def __init__(self, cfg):
 super().__init__(cfg)
@@ -224,7 +223,7 @@ class WarmUpTokenProcessor(TokenProcessor):
 
 def process_sampling_results(self):
 """
-循环获取输出,并处理数据
+get output from model and process it
 """
 while self._is_running:
 try:
@@ -238,6 +237,9 @@ class WarmUpTokenProcessor(TokenProcessor):
 model_server_logger.info("while get input_data error: {0} {1}".format(e, str(traceback.format_exc())))
 
 def stop(self):
+"""
+stop warm up thread
+"""
 self._is_running = False
 self.worker.join()
 model_server_logger.info("warm up thread stop")
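Aside: a minimal, self-contained sketch (my own illustration, not code from the commit) of how the end-of-sequence branch above behaves once a generated token id falls into eos_token_ids; the key names mirror the ones set in this hunk, while finalize_result and the sample task dict are hypothetical.

import time


def finalize_result(result: dict, token_id: int, task: dict, all_tokens: list) -> dict:
    """Toy version of the eos branch: mark the end and attach aggregate token info."""
    result["inference_time_cost"] = time.time() - task["inference_start_time"]
    if token_id in task["eos_token_ids"]:
        result["is_end"] = 1
        result["token_ids"] = []
        result["tokens_all_num"] = len(all_tokens) + 1
        result["tokens_all_ids"] = all_tokens
    return result


task = {"inference_start_time": time.time(), "eos_token_ids": [2]}
print(finalize_result({"req_id": "demo"}, token_id=2, task=task, all_tokens=[11, 12]))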
@@ -27,14 +27,12 @@ from tritonclient import utils as triton_utils
 
 
 class Req(BaseModel):
-"""请求参数的类"""
-# 传入模型服务的请求参数
 req_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
 input_ids: Optional[List[int]] = None
 text: Optional[str] = None
 messages: Optional[List] = None
 max_dec_len: Optional[int] = None
-seq_len: Optional[int] = None # 保留seq_len为了兼容支持
+seq_len: Optional[int] = None
 min_dec_len: Optional[int] = None
 temperature: Optional[float] = None
 topp: Optional[float] = None
@@ -45,12 +43,19 @@ class Req(BaseModel):
 return_all_tokens: Optional[bool] = None
 eos_token_ids: Optional[List[int]] = None
 benchmark: bool = False
-# http服务使用的请求参数
+return_usage: Optional[bool] = False
 stream: bool = False
 timeout: int = 300
 
 def to_dict_for_infer(self):
-"""将请求参数转化为字典,去掉为None的字段,避免传递给模型服务出错"""
+"""
+Convert the request parameters into a dictionary
+
+Returns:
+dict: request parameters in dict format
+"""
+self.compatible_with_OpenAI()
+
 req_dict = {}
 for key, value in self.dict().items():
 if value is not None:
@@ -60,23 +65,24 @@ class Req(BaseModel):
 
 def chat_completion_generator(infer_grpc_url: str, req: Req, yield_json: bool) -> Dict:
 """
-基于Triton推理服务的聊天补全结果的生成器。
+Chat completion generator based on Triton inference service.
 
 Args:
-infer_grpc_url (str): Triton推理服务的gRPC URL。
-req (Request): 聊天补全请求。
-yield_json (bool): 是否返回json格式,否则返回Resp类
+infer_grpc_url (str): Triton gRPC URL。
+req (Request): request parameters
+yield_json (bool): Whether to return the result in json format
 
 Returns:
-dict: 聊天补全结果的生成器。
-如果正常,返回{'token': xxx, 'is_end': xxx, 'send_idx': xxx, ..., 'error_msg': '', 'error_code': 0}
-如果异常,返回{'error_msg': xxx, 'error_code': xxx},error_msg字段不为空,error_code字段不为0
+dict: chat completion result.
+Normal, return {'token': xxx, 'is_end': xxx, 'send_idx': xxx, ..., 'error_msg': '', 'error_code': 0}
+Others, return {'error_msg': xxx, 'error_code': xxx}, error_msg not None, error_code != 0
 """
 class _TritonOutputData:
-"""接收Triton服务返回的数据"""
 def __init__(self):
 self._completed_requests = queue.Queue()
 
 def _triton_callback(output_data, result, error):
-"""Triton客户端的回调函数"""
+"""Triton callback function"""
 if error:
 output_data._completed_requests.put(error)
 else:
@@ -88,7 +94,6 @@ def chat_completion_generator(infer_grpc_url: str, req: Req, yield_json: bool) -
 else:
 return resp_dict
 
-# 准备请求数据
 timeout = req.timeout
 req_id = req.req_id
 req_dict = req.to_dict_for_infer()
@@ -99,16 +104,13 @@ def chat_completion_generator(infer_grpc_url: str, req: Req, yield_json: bool) -
 outputs = [grpcclient.InferRequestedOutput("OUT")]
 output_data = _TritonOutputData()
 
-# 建立连接
 with grpcclient.InferenceServerClient(url=infer_grpc_url, verbose=False) as triton_client:
 triton_client.start_stream(callback=partial(_triton_callback, output_data))
 
-# 发送请求
 triton_client.async_stream_infer(model_name="model",
 inputs=inputs,
 request_id=req_dict['req_id'],
 outputs=outputs)
-# 处理返回结果
 while True:
 output_item = output_data._completed_requests.get(timeout=timeout)
 if type(output_item) == triton_utils.InferenceServerException:
@@ -126,38 +128,35 @@ def chat_completion_generator(infer_grpc_url: str, req: Req, yield_json: bool) -
 if (result.get("error_msg") or result.get("error_code")) or result.get("is_end") == 1:
 break
 
-# 手动关闭连接
 triton_client.stop_stream()
 triton_client.close()
 
 def chat_completion_result(infer_grpc_url: str, req: Req) -> Dict:
 """
-获取非流式生成结果
+Chat completion result with not streaming mode
 
 Args:
-infer_grpc_url (str): gRPC服务地址
-req (Req): 请求参数对象
+infer_grpc_url (str): Triton gRPC URL
+req (Req): request parameters
 
 Returns:
-dict: 聊天补全结果的生成器。
-如果正常,返回{'result': xxx, 'error_msg': '', 'error_code': 0}
-如果异常,返回{'result': '', 'error_msg': xxx, 'error_code': xxx},error_msg字段不为空,error_code字段不为0
+dict: chat completion result.
+Normal, return {'tokens_all': xxx, ..., 'error_msg': '', 'error_code': 0}
+Others, return {'error_msg': xxx, 'error_code': xxx}, error_msg not None, error_code != 0
 """
-result = None
+result = ""
 error_resp = None
 for resp in chat_completion_generator(infer_grpc_url, req, yield_json=False):
 if resp.get("error_msg") or resp.get("error_code"):
 error_resp = resp
 error_resp["result"] = ""
 else:
-if resp.get('is_end') == 1:
-result = resp
-for key in ['token', 'is_end', 'send_idx', 'return_all_tokens', 'token']:
-if key in result:
-del result[key]
-if not result:
-error_resp = {
-"error_msg": "HTTP parsing data error",
-"error_code": 500,
-"result": "",
-"is_end": 1,
-}
-return error_resp if error_resp else result
+result += resp.get("token")
+usage = resp.get("usage", None)
+if error_resp:
+return error_resp
+response = {'result': result, 'error_msg': '', 'error_code': 0}
+if req.return_usage:
+response["usage"] = usage
+return response
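For orientation, a hedged mock (my own, not the repository's implementation) of the aggregation the rewritten chat_completion_result performs: tokens are concatenated and, when return_usage is requested, the final chunk's usage dict is attached. The aggregate helper and the sample chunks are illustrative only.

from typing import Dict, Iterable


def aggregate(chunks: Iterable[Dict], return_usage: bool = False) -> Dict:
    """Mimic the non-streaming aggregation: concatenate tokens, keep the last usage."""
    text, usage, error = "", None, None
    for chunk in chunks:
        if chunk.get("error_msg") or chunk.get("error_code"):
            error = dict(chunk, result="")
        else:
            text += chunk.get("token", "")
            usage = chunk.get("usage")
    if error:
        return error
    response = {"result": text, "error_msg": "", "error_code": 0}
    if return_usage:
        response["usage"] = usage
    return response


print(aggregate(
    [{"token": "Hi", "is_end": 0}, {"token": "!", "is_end": 1, "usage": {"completion_tokens": 2}}],
    return_usage=True,
))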
@@ -18,27 +18,25 @@ import os
 import uvicorn
 from fastapi import FastAPI
 from fastapi.responses import StreamingResponse
-from server.http_server.api import (
-Req,
-chat_completion_generator,
-chat_completion_result,
-)
+from server.http_server.api import (Req, chat_completion_generator,
+chat_completion_result)
 from server.utils import http_server_logger
 
 http_server_logger.info(f"create fastapi app...")
 app = FastAPI()
 
 
 @app.post("/v1/chat/completions")
 def create_chat_completion(req: Req):
 """
-服务端路由函数
-返回:
-如果stream为True,流式返回
-如果正常,返回{'token': xxx, 'is_end': xxx, 'send_idx': xxx, ..., 'error_msg': '', 'error_code': 0}
-如果异常,返回{'error_msg': xxx, 'error_code': xxx},error_msg字段不为空,error_code字段不为0
-如果stream为False,非流式返回
-如果正常,返回{'result': xxx, 'error_msg': '', 'error_code': 0}
-如果异常,返回{'result': '', 'error_msg': xxx, 'error_code': xxx},error_msg字段不为空,error_code字段不为0
+HTTP Server for chat completion
+Return:
+In Stream:
+Normal, return {'token': xxx, 'is_end': xxx, 'send_idx': xxx, ..., 'error_msg': '', 'error_code': 0}
+Others, return {'error_msg': xxx, 'error_code': xxx}, error_msg not None, error_code != 0
+Not In Stream:
+Normal, return {'tokens_all': xxx, ..., 'error_msg': '', 'error_code': 0}
+Others, return {'error_msg': xxx, 'error_code': xxx}, error_msg not None, error_code != 0
 """
 try:
 http_server_logger.info(f"receive request: {req.req_id}")
@@ -59,11 +57,12 @@ def create_chat_completion(req: Req):
 http_server_logger.info(f"finish request: {req.req_id}")
 return resp
 
 
 def launch_http_server(port: int, workers: int) -> None:
 """
-启动http服务
+launch http server
 """
-http_server_logger.info(f"launch http server... port: {port}, workers: {workers}")
+http_server_logger.info(f"launch http server with port: {port}, workers: {workers}")
 try:
 uvicorn.run(app="server.http_server.app:app",
 host='0.0.0.0',
@@ -73,13 +72,14 @@ def launch_http_server(port: int, workers: int) -> None:
 except Exception as e:
 http_server_logger.error(f"launch http server error, {e}")
 
 
 def main():
-"""main函数"""
 parser = argparse.ArgumentParser()
 parser.add_argument("--port", default=9904, type=int, help="port to the http server")
 parser.add_argument("--workers", default=1, type=int, help="set the number of workers for the http service")
 args = parser.parse_args()
 launch_http_server(port=args.port, workers=args.workers)
 
 
 if __name__ == "__main__":
 main()
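A possible client call against the route above (illustrative, not repository code): it assumes the server is running locally on the default port 9904 shown in main(), and uses the text, max_dec_len, stream and return_usage fields declared on Req; the exact fields a deployment requires may differ.

import requests

payload = {
    "text": "Hello",          # prompt text; messages-style input also exists on Req
    "max_dec_len": 64,
    "stream": False,          # non-streaming: expect {'result': ..., 'error_msg': '', 'error_code': 0}
    "return_usage": True,
}
resp = requests.post("http://127.0.0.1:9904/v1/chat/completions", json=payload, timeout=60)
print(resp.json())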
@@ -26,10 +26,7 @@ from collections import Counter, deque
 from datetime import datetime
 
 import numpy as np
-from server.checker import (
-add_default_params,
-check_basic_params,
-)
+from server.checker import add_default_params, check_basic_params
 from server.engine import engine
 from server.engine.config import Config
 from server.utils import error_logger, model_server_logger
@@ -50,7 +47,7 @@ if sys.stdout.encoding is None:
 
 class TritonConfig(Config):
 """
-Triton Inference Server额外增加的配置参数
+Triton Inference Server config
 """
 def __init__(self, base_config):
 super().__init__()
@@ -60,16 +57,13 @@ class TritonConfig(Config):
 
 class TritonTokenProcessor(engine.TokenProcessor):
 """
-创建Triton服务的Processor
+initialize Triton Processor
 """
 def __init__(self, cfg, triton_server):
 super().__init__(cfg)
 self.triton_server = triton_server
-# 缓存的结果
 self.cached_generated_tokens = queue.Queue()
-# Token缓存,针对部分特殊Token累积后再发送
 self.token_buffer = dict()
-# Score缓存,针对部分特殊Token累积后再发送
 self.score_buffer = dict()
 
 self.push_mode_sender_thread = threading.Thread(target=self._push_mode_sender_thread, args=())
@@ -77,6 +71,9 @@ class TritonTokenProcessor(engine.TokenProcessor):
 self.push_mode_sender_thread.start()
 
 def _push_mode_sender_thread(self):
+"""
+push mode sender thread
+"""
 while True:
 try:
 batch_result = self.cached_generated_tokens.get()
@@ -84,24 +81,26 @@ class TritonTokenProcessor(engine.TokenProcessor):
 req_id = result["req_id"]
 is_end = result.get("is_end", 0)
 return_all_tokens = result.get("return_all_tokens", False)
-# 非流式返回下仅返回最后一个Token结果
 if is_end == 0 and (return_all_tokens or self.cfg.disable_streaming):
 continue
 if return_all_tokens and "topk_tokens" in result:
 del result["topk_tokens"]
 result = self.triton_server.data_processor.process_response(result)
+if "usage" in result:
+result["usage"]["prompt_tokens"] = self.triton_server.task_info[req_id]["prompt_tokens"]
 model_server_logger.debug(f"Send result to client under push mode: {result}")
 with self.triton_server.thread_lock:
 _send_result([result], self.triton_server.response_sender[req_id], is_end)
 if is_end == 1:
 del self.triton_server.response_sender[req_id]
+del self.triton_server.task_info[req_id]
 self.triton_server._update_metrics()
 except Exception as e:
 model_server_logger.error("Unexcepted error happend: {}, {}".format(e, str(traceback.format_exc())))
 
 def postprocess(self, batch_result, exist_finished_task=False):
 """
-生成单步结果后处理函数
+single postprocess for triton
 """
 try:
 self.cached_generated_tokens.put(batch_result)
@@ -113,25 +112,24 @@ class TritonTokenProcessor(engine.TokenProcessor):
 
 class TritonServer(object):
 """
-Triton框架服务实现
+Triton Server
 """
 
 def initialize(self, args):
 """
-Triton服务初始化
+Triton initialization
 """
-# 开启探活服务
+# start health checker
 use_custom_health_checker = int(os.getenv("USE_CUSTOM_HEALTH_CHECKER", 1))
-# 环境变量USE_CUSTOM_HEALTH_CHECKER:控制是否使用自定义的探活接口
-# 使用自定义的探活接口时候,tritonserver自身的探活服务需要被关闭,当USE_CUSTOM_HEALTH_CHECKER为1时,需要--allow-http设置为false
-# 当USE_CUSTOM_HEALTH_CHECKER为0时,tritonserver自身的探活服务需要打开,设置--http-port=${HTTP_PORT}
+# if set USE_CUSTOM_HEALTH_CHECKER=1, use custom health checker, need set --allow-http=false
+# else use tritonserver's health checker, need set --http-port=${HTTP_PORT}
 if use_custom_health_checker:
 http_port = os.getenv("HTTP_PORT")
 if http_port is None:
 raise Exception("HTTP_PORT must be set")
 from server.triton_server_helper import start_health_checker
 multiprocessing.Process(target=start_health_checker, args=(int(http_port), )).start()
-time.sleep(1) # 等待1s,保证需要的共享内存已经创建
+time.sleep(1)
 
 model_config = json.loads(args["model_config"])
 using_decoupled = pb_utils.using_decoupled_model_transaction_policy(
@@ -142,7 +140,7 @@ class TritonServer(object):
 enable decoupled transaction policy in model configuration to
 serve this model""".format(args["model_name"]))
 
-# 添加metrics指标,可以通过 METRICS_PORT 获取服务状态
+# add metrics,use METRICS_PORT get server metrics
 self.metric_family = pb_utils.MetricFamily(
 name="inference_server_metrics",
 description="Metrics for monitoring inference server status",
@@ -165,15 +163,14 @@ class TritonServer(object):
 labels={"available_resource": "available_resource"}),
 }
 
-# Triton服务所需变量
-# response_sender的线程锁,避免多线程访问或读写时的问题
+# response_sender thread lock
 self.thread_lock = threading.Lock()
 
 base_config = Config()
 self.cfg = TritonConfig(base_config)
 self.cfg.print(file="log/fastdeploy_init.info")
 
-# 初始化底层引擎
+# init engine
 self.token_processor = TritonTokenProcessor(self.cfg, self)
 self.engine = engine.Engine(self.cfg, self.token_processor)
 model_server_logger.info("Creat engine...")
@@ -186,7 +183,8 @@ class TritonServer(object):
 
 def execute(self, requests):
 """
-Triton服务主函数,处理Triton框架接收的请求
+Triton service main function,
+handling requests received by the Triton framework
 """
 if len(requests) != 1:
 raise pb_utils.TritonModelException(
@@ -202,7 +200,7 @@ class TritonServer(object):
 
 def finalize(self):
 """
-Triton服务退出函数
+Triton service exit function
 """
 model_server_logger.info("Triton service will be terminated...")
 wait_time = 300
@@ -226,7 +224,6 @@ class TritonServer(object):
 self.data_processor = DataProcessor()
 model_server_logger.info("create data processor success")
 
-# 是否开启HTTP协议支持
 if self.cfg.push_mode_http_port < 0:
 model_server_logger.info("HTTP server for push mode is disabled.")
 else:
@@ -251,11 +248,9 @@ class TritonServer(object):
 model_server_logger.error(error_msg)
 model_server_logger.info("init push server success")
 
-# 需要维护每个请求的通信句柄
 self.response_sender = dict()
-# 请求队列,从左侧插入,从右侧取出
+self.task_info = dict()
 self.cached_task_deque = deque()
-# 持续监控引擎和请求队列,当引擎有资源时,从请求队列中获取数据,插入到引擎内
 self.enable_insert_task_push_mode = True
 self.insert_task_to_engine_thread = threading.Thread(
 target=self._insert_task_push_mode, args=())
@@ -264,10 +259,13 @@ class TritonServer(object):
 
 def _process_task_push_mode(self, tasks, current_response_sender):
 """
-针对推模式,对请求进行检查,如果没问题则插入到cached_task_deque中。
+check request and insert into cached_task_deque
+
+Args:
+tasks (list): list of request
+current_response_sender: response sender for current request
 """
 try:
-# 基础检查,如果检查失败,则直接返回错误信息
 tik = time.time()
 req_id = tasks[0]["req_id"]
 cached_task_num = len(self.cached_task_deque)
@@ -299,17 +297,14 @@ class TritonServer(object):
 _send_error(error_msg, current_response_sender, req_id=req_id)
 return
 
-# 添加默认参数
 task = add_default_params(task)
 
-# 拼接和tokenizer处理,默认支持截断
 if int(task.get("enable_text_truncate", 1)):
 real_seq_len = self.cfg.max_seq_len - task.get("max_dec_len", 800)
 task = self.data_processor.process_request(task, max_seq_len=real_seq_len)
 else:
 task = self.data_processor.process_request(task)
 
-# 检查输入长度
 input_ids_len = len(task["input_ids"])
 if "max_dec_len" not in task:
 task["max_dec_len"] = min(self.cfg.max_seq_len - input_ids_len, self.cfg.dec_len_limit)
@@ -336,8 +331,8 @@ class TritonServer(object):
 return
 
 with self.thread_lock:
-# 插入缓存队列
 self.response_sender[task_id] = current_response_sender
+self.task_info[task_id] = {"prompt_tokens": input_ids_len}
 
 task["preprocess_end_time"] = datetime.now()
 self.cached_task_deque.appendleft(task)
@@ -352,10 +347,8 @@ class TritonServer(object):
 
 def _insert_task_push_mode(self):
 """
-推push模式下的持续处理缓存task的线程,一旦有资源将缓存task插入到引擎中。
-1. 所有接收到的请求会先插入到cached_task_deque
-2. _insert_task_push_mode线程持续监控引擎
-3. 一旦有资源可用,从cached_task_deque取出数据,提交给引擎
+Insert task to engine thread, monitor cached_task_deque.
+if the engine has resource, insert task to engine
 """
 try:
 while self.enable_insert_task_push_mode:
@@ -384,7 +377,6 @@ class TritonServer(object):
 i_bs += 1
 if i_bs >= self.cfg.max_batch_size:
 break
-# 此处无需加锁,execute中插入cached_task_deque的方向与-1的方向不同
 input_token_num = len(self.cached_task_deque[-1]["input_ids"])
 if not self.engine.is_resource_sufficient(input_token_num):
 break
@@ -405,7 +397,7 @@ class TritonServer(object):
 
 def _update_metrics(self):
 """
-更新监控指标
+update metrics
 """
 block_num = self.engine.available_block_num()
 batch_size = self.engine.available_batch()
@@ -418,7 +410,7 @@ class TritonServer(object):
 
 def _get_current_server_info(self):
 """
-获取服务当前资源信息
+get server info
 """
 available_batch_size = min(self.cfg.max_prefill_batch,
 self.engine.available_batch())
@@ -436,12 +428,12 @@ class TritonServer(object):
 
 def _send_result(result_dict, sender, end_flag=0):
 """
-向推理引擎发送推理结果。
+Send inference result
 
 Args:
-result_dict (dict): 推理结果,以字典形式存储。
-sender (grpc.aio.ServerReaderWriter): gRPC的ServerReaderWriter对象,用于发送推理结果。
-end_flag (int, optional): 标志位,用于标识是否发送结束信号。默认为0。
+result_dict (dict): result of inference
+sender (grpc.aio.ServerReaderWriter): gRPC ServerReaderWriter object.
+end_flag (int, optional): flag of end. Defaults to 0.
 """
 response = None
 if result_dict:
@@ -455,12 +447,13 @@ def _send_result(result_dict, sender, end_flag=0):
 
 def _send_error(error_msg, sender, error_code=200, req_id=None):
 """
-向发送方发送错误信息
+Send error inference result
 
 Args:
-error_msg (str): 错误信息
-sender (str): 发送方标识
-error_code (int, optional): 错误码. Defaults to 200.
+error_msg (str): error message
+sender (grpc.aio.ServerReaderWriter): gRPC ServerReaderWriter object.
+error_code (int, optional): error code. Defaults to 200.
+req_id (str, optional): request id. Defaults to None
 """
 if not isinstance(error_msg, str):
 error_msg = str(error_msg)
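A condensed sketch (assumptions of mine, not the repository's implementation) of the bookkeeping this section introduces around task_info: prompt token counts are recorded when a task is accepted and merged into the usage block just before the final result is sent; the accept/send helpers below are hypothetical stand-ins for _process_task_push_mode and _push_mode_sender_thread.

import threading

task_info = {}              # req_id -> {"prompt_tokens": int}
lock = threading.Lock()     # mirrors the thread_lock guarding shared per-request state


def accept(req_id: str, input_ids: list) -> None:
    """Record prompt length when a request is admitted."""
    with lock:
        task_info[req_id] = {"prompt_tokens": len(input_ids)}


def send(result: dict) -> dict:
    """Attach prompt_tokens to usage and drop the bookkeeping on the final chunk."""
    req_id = result["req_id"]
    if "usage" in result:
        result["usage"]["prompt_tokens"] = task_info[req_id]["prompt_tokens"]
    if result.get("is_end") == 1:
        with lock:
            del task_info[req_id]
    return result


accept("r1", [1, 2, 3])
print(send({"req_id": "r1", "is_end": 1, "usage": {"completion_tokens": 5}}))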
@@ -29,14 +29,15 @@ from server.engine.config import Config
 from server.utils import get_logger
 
 app = FastAPI()
 
-logger = get_logger("health_checker", "health_checker.log")
 env_config = Config()
+logger = get_logger("health_checker", "health_checker.log")
 
 
 @app.get("/v2/health/ready")
 def check_health():
 """
-探活接口"""
+health check interface
+"""
 status, error_info = check()
 if status is True:
 logger.info("check_health: OK")
@@ -51,7 +52,8 @@ def check_health():
 @app.get("/v2/health/live")
 def check_live():
 """
-探活接口"""
+health check interface
+"""
 status, error_info = check()
 if status is True:
 logger.info("check_health: OK")
@@ -64,24 +66,32 @@ def check_live():
 
 
 def check_infer_engine_process():
-# 检查infer进程是否存在
+"""
+check if infer process is alive
+
+return:
+status: bool, True if process is alive else False
+"""
 mp_num = int(env_config.mp_num)
 for i in range(mp_num):
 try:
 infer_live_flag_shm = shared_memory.SharedMemory(name=env_config.get_unique_name("shm_flag_infer_{}_live".format(i)))
-except Exception as e: # infer掉了会报异常
+except Exception as e:
 return False
 return True
 
 
 def check():
 """
-推理服务的状态探活接口
+State detection interface for inference services
+
+return:
+status: bool, True if process is alive else False
 """
 error_info = {}
 grpc_port = os.getenv("GRPC_PORT")
 
-# 1. 检查server是否健康
+# 1. check server is ready
 if grpc_port is not None:
 sock = socket.socket()
 try:
@@ -94,7 +104,7 @@ def check():
 finally:
 sock.close()
 
-# 2. 检查engine是否健康
+# 2.check engine is ready
 is_engine_live = check_infer_engine_process()
 if is_engine_live is False:
 error_info["error_code"] = 2
@@ -102,16 +112,15 @@ def check():
 logger.info("infer engine is down")
 return False, error_info
 
-# 检查是否启动
 engine_ready_checker = np.ndarray(engine_ready_check_flag.shape, dtype=engine_ready_check_flag.dtype,
 buffer=shm_engine_ready_check_flag.buf)
-if engine_ready_checker[0] == 0: # 值为0代表没启动,值为1代表已启动
+if engine_ready_checker[0] == 0:
 error_info["error_code"] = 2
 error_info["error_msg"] = "infer engine is down"
 logger.info("infer engine is down")
 return False, error_info
 
-# 检查是否hang住
+# check engine is hang
 engine_hang_checker = np.ndarray(engine_healthy_recorded_time.shape, dtype=engine_healthy_recorded_time.dtype,
 buffer=shm_engine_healthy_recorded_time.buf)
 elapsed_time = time.time() - engine_hang_checker[0]
@@ -132,15 +141,17 @@ def start_health_checker(http_port):
 uvicorn.run(app=app, host='0.0.0.0', port=http_port, workers=1, log_level="info")
 
 
-time_interval_threashold = env_config.check_health_interval # 10s infer engine没有执行过while循环,则判定hang死或挂掉等问题
+# if infer engine not update for more than 10 seconds,consider it as hang or dead
+time_interval_threashold = env_config.check_health_interval
 engine_healthy_recorded_time = np.zeros([1], dtype=float)
 
 shm_engine_healthy_recorded_time = shared_memory.SharedMemory(
 create=True,
 size=engine_healthy_recorded_time.nbytes,
-name=env_config.get_unique_name("engine_healthy_recorded_time")) # 由推理引擎进行更新,每次读token时候就刷新一次时间,正常情况下该时间戳在30s内肯定会被刷新
+name=env_config.get_unique_name("engine_healthy_recorded_time"))
 
 engine_ready_check_flag = np.zeros([1], dtype=np.int32)
 shm_engine_ready_check_flag = shared_memory.SharedMemory(
 create=True,
 size=engine_ready_check_flag.nbytes,
-name=env_config.get_unique_name("engine_ready_check_flag")) # 由推理引擎更新,推理引擎初始化完毕时候置为1
+name=env_config.get_unique_name("engine_ready_check_flag"))
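An illustrative probe of the two endpoints defined above (my own example, not part of the commit): the /v2/health/ready and /v2/health/live paths come from the decorators in this file, while the port value is an assumption standing in for whatever HTTP_PORT the checker was started with.

import requests

HTTP_PORT = 8080  # assumed; use the HTTP_PORT the health checker actually listens on
for path in ("/v2/health/ready", "/v2/health/live"):
    r = requests.get(f"http://127.0.0.1:{HTTP_PORT}{path}", timeout=5)
    print(path, r.status_code)  # 200 when the gRPC server and infer engine look healthy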
@@ -18,19 +18,17 @@ import logging
 import os
 import pickle
 import re
+import subprocess
 import time
 from datetime import datetime
 from enum import Enum
 from logging.handlers import BaseRotatingHandler
 from pathlib import Path
-import subprocess
 
 
 class DailyRotatingFileHandler(BaseRotatingHandler):
 """
-- 可以支持多进程
-- 只支持自然日分割
-- 暂不支持UTC
+like `logging.TimedRotatingFileHandler`, but this class support multi-process
 """
 
 def __init__(
@@ -53,7 +51,7 @@ class DailyRotatingFileHandler(BaseRotatingHandler):
 
 def shouldRollover(self, record):
 """
-判断是否该滚动日志,如果当前时间对应的日志文件名与当前打开的日志文件名不一致,则需要滚动日志
+check scroll through the log
 """
 if self.current_filename != self._compute_fn():
 return True
@@ -61,7 +59,7 @@ class DailyRotatingFileHandler(BaseRotatingHandler):
 
 def doRollover(self):
 """
-滚动日志
+scroll log
 """
 if self.stream:
 self.stream.close()
@@ -77,20 +75,19 @@ class DailyRotatingFileHandler(BaseRotatingHandler):
 
 def _compute_fn(self):
 """
-计算当前时间对应的日志文件名
+Calculate the log file name corresponding current time
 """
 return self.base_filename + "." + time.strftime(self.suffix, time.localtime())
 
 def _open(self):
 """
-打开新的日志文件,同时更新base_filename指向的软链,修改软链不会对日志记录产生任何影响
+open new log file
 """
 if self.encoding is None:
 stream = open(str(self.current_log_path), self.mode)
 else:
 stream = codecs.open(str(self.current_log_path), self.mode, self.encoding)
 
-# 删除旧的软链
 if self.base_log_path.exists():
 try:
 if (
@@ -109,7 +106,7 @@ class DailyRotatingFileHandler(BaseRotatingHandler):
 
 def delete_expired_files(self):
 """
-删除过期的日志
+delete expired log files
 """
 if self.backup_count <= 0:
 return
@@ -135,7 +132,7 @@ class DailyRotatingFileHandler(BaseRotatingHandler):
 
 def get_logger(name, file_name, without_formater=False):
 """
-获取logger
+get logger
 """
 log_dir = os.getenv("FD_LOG_DIR", default="log")
 is_debug = int(os.getenv("FD_DEBUG", default=0))
@@ -158,16 +155,11 @@ def get_logger(name, file_name, without_formater=False):
 handler.propagate = False
 return logger
 
-# 实例化单例logger
-model_server_logger = get_logger("model_server", "infer_server.log")
-http_server_logger = get_logger("http_server", "http_server.log")
-data_processor_logger = get_logger("data_processor", "data_processor.log")
-monitor_logger = get_logger("monitor_logger", "monitor_logger.log", True)
-error_logger = get_logger("error_logger", "error_logger.log", True)
 
 
 def str_to_datetime(date_string):
-"""datetime字符串转datetime对象"""
+"""
+string to datetime class object
+"""
 if "." in date_string:
 return datetime.strptime(date_string, "%Y-%m-%d %H:%M:%S.%f")
 else:
@@ -176,14 +168,14 @@ def str_to_datetime(date_string):
 
 def datetime_diff(datetime_start, datetime_end):
 """
-计算两个日期时间之间的差值(以秒为单位)。
+Calculate the difference between two dates and times(s)
 
 Args:
-datetime_start (Union[str, datetime.datetime]): 开始时间,可以是字符串或datetime.datetime对象。
-datetime_end (Union[str, datetime.datetime]): 结束时间,可以是字符串或datetime.datetime对象。
+datetime_start (Union[str, datetime.datetime]): start time
+datetime_end (Union[str, datetime.datetime]): end time
 
 Returns:
-float: 日期时间差值,以秒为单位。
+float: date time difference(s)
 """
 if isinstance(datetime_start, str):
 datetime_start = str_to_datetime(datetime_start)
@@ -194,3 +186,10 @@ def datetime_diff(datetime_start, datetime_end):
 else:
 cost = datetime_start - datetime_end
 return cost.total_seconds()
+
+
+model_server_logger = get_logger("model_server", "infer_server.log")
+http_server_logger = get_logger("http_server", "http_server.log")
+data_processor_logger = get_logger("data_processor", "data_processor.log")
+monitor_logger = get_logger("monitor_logger", "monitor_logger.log", True)
+error_logger = get_logger("error_logger", "error_logger.log", True)
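A small usage sketch for the time helpers in this file (illustrative only): it assumes server.utils is importable from the deployment environment and that the FD_LOG_DIR directory is writable, since importing the module instantiates the loggers shown above.

from datetime import datetime, timedelta

from server.utils import datetime_diff, str_to_datetime  # module shown in this diff

start = datetime.now()
end = start + timedelta(seconds=2.5)
print(datetime_diff(start, end))           # -> 2.5, datetime inputs
print(datetime_diff(str(start), end))      # string start goes through str_to_datetime
print(str_to_datetime("2024-01-01 08:00:00"))  # branch without fractional seconds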