mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 09:07:10 +08:00
[CE] Add base test class for web server testing (#3120)
* add test base class * fix codestyle * fix codestyle
This commit is contained in:
34
test/ce/server/core/__init__.py
Normal file
34
test/ce/server/core/__init__.py
Normal file
@@ -0,0 +1,34 @@
|
||||
#!/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
# @author DDDivano
|
||||
# encoding=utf-8 vi:ts=4:sw=4:expandtab:ft=python
|
||||
import os
|
||||
import sys
|
||||
|
||||
from .logger import Logger
|
||||
|
||||
base_logger = Logger(loggername="FDSentry", save_level="channel", log_path="./fd_logs").get_logger()
|
||||
base_logger.setLevel("INFO")
|
||||
from .request_template import TEMPLATES
|
||||
from .utils import build_request_payload, send_request
|
||||
|
||||
__all__ = ["build_request_payload", "send_request", "TEMPLATES"]
|
||||
|
||||
# 检查环境变量是否存在
|
||||
URL = os.environ.get("URL")
|
||||
TEMPLATE = os.environ.get("TEMPLATE")
|
||||
|
||||
missing_vars = []
|
||||
if not URL:
|
||||
missing_vars.append("URL")
|
||||
if not TEMPLATE:
|
||||
missing_vars.append("TEMPLATE")
|
||||
|
||||
if missing_vars:
|
||||
msg = (
|
||||
f"❌ 缺少环境变量:{', '.join(missing_vars)},请先设置,例如:\n"
|
||||
f" export URL=http://localhost:8000/v1/chat/completions\n"
|
||||
f" export TEMPLATE=TOKEN_LOGPROB"
|
||||
)
|
||||
base_logger.error(msg)
|
||||
sys.exit(1) # 终止程序
|
99
test/ce/server/core/logger.py
Normal file
99
test/ce/server/core/logger.py
Normal file
@@ -0,0 +1,99 @@
|
||||
#!/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
# @author DDDivano
|
||||
# encoding=utf-8 vi:ts=4:sw=4:expandtab:ft=python
|
||||
"""
|
||||
ServeTest
|
||||
"""
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
import pytz
|
||||
|
||||
|
||||
class Logger(object):
|
||||
"""
|
||||
日志记录配置的基础类。
|
||||
"""
|
||||
|
||||
SAVE_LEVELS = ["both", "file", "channel"]
|
||||
LOG_FORMAT = "%(asctime)s - %(name)s - [%(levelname)s] - %(message)s"
|
||||
|
||||
def __init__(self, loggername, save_level="both", log_path=None):
|
||||
"""
|
||||
使用指定名称和保存级别初始化日志记录器。
|
||||
|
||||
Args:
|
||||
loggername (str): 日志记录器的名称。
|
||||
save_level (str): 日志保存的级别。默认为"both"。file: 仅保存到文件,channel: 仅保存到控制台。
|
||||
log_path (str, optional): 日志文件保存路径。默认为None。
|
||||
"""
|
||||
|
||||
if save_level not in self.SAVE_LEVELS:
|
||||
raise ValueError(f"Invalid save level: {save_level}. Allowed values: {self.SAVE_LEVELS}")
|
||||
|
||||
self.logger = logging.getLogger(loggername)
|
||||
self.logger.setLevel(logging.DEBUG)
|
||||
|
||||
# 设置时区为东八区
|
||||
tz = pytz.timezone("Asia/Shanghai")
|
||||
|
||||
# 自定义时间格式化器,指定时区为东八区
|
||||
class CSTFormatter(logging.Formatter):
|
||||
"""
|
||||
自定义时间格式化器,指定时区为东八区
|
||||
"""
|
||||
|
||||
def converter(self, timestamp):
|
||||
"""
|
||||
自定义时间转换函数,加上时区信息
|
||||
Args:
|
||||
timestamp (int): 时间戳。
|
||||
Returns:
|
||||
tuple: 格式化后的时间元组。
|
||||
"""
|
||||
dt = datetime.utcfromtimestamp(timestamp)
|
||||
dt = pytz.utc.localize(dt).astimezone(tz)
|
||||
return dt.timetuple()
|
||||
|
||||
formatter = CSTFormatter(self.LOG_FORMAT)
|
||||
log_name = None
|
||||
if save_level == "both" or save_level == "file":
|
||||
os.makedirs(log_path, exist_ok=True)
|
||||
log_filename = f"out_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.log"
|
||||
log_name = os.path.join(log_path, log_filename)
|
||||
file_handler = logging.FileHandler(log_name, encoding="utf-8")
|
||||
file_handler.setLevel(logging.DEBUG)
|
||||
file_handler.setFormatter(formatter)
|
||||
self.logger.addHandler(file_handler)
|
||||
|
||||
if save_level == "both" or save_level == "channel":
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setLevel(logging.DEBUG)
|
||||
console_handler.setFormatter(formatter)
|
||||
self.logger.addHandler(console_handler)
|
||||
|
||||
if log_name is None:
|
||||
self.logger.info(
|
||||
f"Logger initialized. Log level: {save_level}. "
|
||||
f"Log path ({log_path}) is unused according to the level."
|
||||
)
|
||||
else:
|
||||
self.logger.info(f"Logger initialized. Log level: {save_level}. Log path: {log_name}")
|
||||
# Adjusting the timezone offset
|
||||
|
||||
def get_logger(self):
|
||||
"""
|
||||
Get the logger object
|
||||
"""
|
||||
return self.logger
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Test the logger
|
||||
logger = Logger("test_logger", save_level="channel").get_logger()
|
||||
logger.info("the is the beginning")
|
||||
logger.debug("the is the beginning")
|
||||
logger.warning("the is the beginning")
|
||||
logger.error("the is the beginning")
|
25
test/ce/server/core/request_template.py
Normal file
25
test/ce/server/core/request_template.py
Normal file
@@ -0,0 +1,25 @@
|
||||
#!/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
# @author DDDivano
|
||||
# encoding=utf-8 vi:ts=4:sw=4:expandtab:ft=python
|
||||
"""
|
||||
ServeTest
|
||||
"""
|
||||
|
||||
|
||||
TOKEN_LOGPROB = {
|
||||
"model": "default",
|
||||
"temperature": 0,
|
||||
"top_p": 0,
|
||||
"seed": 33,
|
||||
"stream": True,
|
||||
"logprobs": True,
|
||||
"top_logprobs": 5,
|
||||
"max_tokens": 10000,
|
||||
}
|
||||
|
||||
|
||||
TEMPLATES = {
|
||||
"TOKEN_LOGPROB": TOKEN_LOGPROB,
|
||||
# "ANOTHER_TEMPLATE": ANOTHER_TEMPLATE
|
||||
}
|
57
test/ce/server/core/utils.py
Normal file
57
test/ce/server/core/utils.py
Normal file
@@ -0,0 +1,57 @@
|
||||
#!/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
# @author DDDivano
|
||||
# encoding=utf-8 vi:ts=4:sw=4:expandtab:ft=python
|
||||
|
||||
import requests
|
||||
from core import TEMPLATES, base_logger
|
||||
|
||||
|
||||
def build_request_payload(template_name: str, case_data: dict) -> dict:
|
||||
"""
|
||||
基于模板构造请求 payload,按优先级依次合并:
|
||||
template < payload 参数 < case_data,后者会覆盖前者的同名字段。
|
||||
|
||||
:param template_name: 模板变量名,例如 "TOKEN_LOGPROB"
|
||||
:return: 构造后的完整请求 payload dict
|
||||
"""
|
||||
template = TEMPLATES[template_name]
|
||||
print(template)
|
||||
final_payload = template.copy()
|
||||
final_payload.update(case_data)
|
||||
|
||||
return final_payload
|
||||
|
||||
|
||||
def send_request(url, payload, timeout=600, stream=False):
|
||||
"""
|
||||
向指定URL发送POST请求,并返回响应结果。
|
||||
|
||||
Args:
|
||||
url (str): 请求的目标URL。
|
||||
payload (dict): 请求的负载数据,应该是一个字典类型。
|
||||
timeout (int, optional): 请求的超时时间,默认为600秒。
|
||||
stream (bool, optional): 是否以流的方式下载响应内容,默认为False。
|
||||
|
||||
Returns:
|
||||
response: 请求的响应结果,如果请求失败则返回None。
|
||||
|
||||
Raises:
|
||||
None
|
||||
|
||||
"""
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
base_logger.info("🔄 正在请求模型接口...")
|
||||
|
||||
try:
|
||||
res = requests.post(url, headers=headers, json=payload, stream=stream, timeout=timeout)
|
||||
base_logger.info("🟢 接收响应中...\n")
|
||||
return res
|
||||
except requests.exceptions.Timeout:
|
||||
base_logger.error(f"❌ 请求超时(超过 {timeout} 秒)")
|
||||
return None
|
||||
except requests.exceptions.RequestException as e:
|
||||
base_logger.error(f"❌ 请求失败:{e}")
|
||||
return None
|
48
test/ce/server/demo.py
Normal file
48
test/ce/server/demo.py
Normal file
@@ -0,0 +1,48 @@
|
||||
#!/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
# @author DDDivano
|
||||
# encoding=utf-8 vi:ts=4:sw=4:expandtab:ft=python
|
||||
|
||||
from core import TEMPLATE, URL, build_request_payload, send_request
|
||||
|
||||
|
||||
def demo():
|
||||
data = {
|
||||
"stream": False,
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "牛顿的三大运动定律是什么?"},
|
||||
],
|
||||
"max_tokens": 3,
|
||||
}
|
||||
payload = build_request_payload(TEMPLATE, data)
|
||||
req = send_request(URL, payload)
|
||||
print(req.json())
|
||||
req = req.json()
|
||||
|
||||
assert req["usage"]["prompt_tokens"] == 22
|
||||
assert req["usage"]["total_tokens"] == 25
|
||||
assert req["usage"]["completion_tokens"] == 3
|
||||
|
||||
|
||||
def test_demo():
|
||||
data = {
|
||||
"stream": False,
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "牛顿的三大运动定律是什么?"},
|
||||
],
|
||||
"max_tokens": 3,
|
||||
}
|
||||
payload = build_request_payload(TEMPLATE, data)
|
||||
req = send_request(URL, payload)
|
||||
print(req.json())
|
||||
req = req.json()
|
||||
|
||||
assert req["usage"]["prompt_tokens"] == 22
|
||||
assert req["usage"]["total_tokens"] == 25
|
||||
assert req["usage"]["completion_tokens"] == 5
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo()
|
Reference in New Issue
Block a user