[CE] Add base test class for web server testing (#3120)

* add test base class

* fix codestyle

* fix codestyle
This commit is contained in:
Divano
2025-07-31 23:28:50 +08:00
committed by GitHub
parent e1011e92d9
commit 1d93565082
5 changed files with 263 additions and 0 deletions

View File

@@ -0,0 +1,34 @@
#!/bin/env python3
# -*- coding: utf-8 -*-
# @author DDDivano
# encoding=utf-8 vi:ts=4:sw=4:expandtab:ft=python
import os
import sys
from .logger import Logger
base_logger = Logger(loggername="FDSentry", save_level="channel", log_path="./fd_logs").get_logger()
base_logger.setLevel("INFO")
from .request_template import TEMPLATES
from .utils import build_request_payload, send_request
__all__ = ["build_request_payload", "send_request", "TEMPLATES"]
# 检查环境变量是否存在
URL = os.environ.get("URL")
TEMPLATE = os.environ.get("TEMPLATE")
missing_vars = []
if not URL:
missing_vars.append("URL")
if not TEMPLATE:
missing_vars.append("TEMPLATE")
if missing_vars:
msg = (
f"❌ 缺少环境变量:{', '.join(missing_vars)},请先设置,例如:\n"
f" export URL=http://localhost:8000/v1/chat/completions\n"
f" export TEMPLATE=TOKEN_LOGPROB"
)
base_logger.error(msg)
sys.exit(1) # 终止程序

View File

@@ -0,0 +1,99 @@
#!/bin/env python3
# -*- coding: utf-8 -*-
# @author DDDivano
# encoding=utf-8 vi:ts=4:sw=4:expandtab:ft=python
"""
ServeTest
"""
import logging
import os
from datetime import datetime
import pytz
class Logger(object):
"""
日志记录配置的基础类。
"""
SAVE_LEVELS = ["both", "file", "channel"]
LOG_FORMAT = "%(asctime)s - %(name)s - [%(levelname)s] - %(message)s"
def __init__(self, loggername, save_level="both", log_path=None):
"""
使用指定名称和保存级别初始化日志记录器。
Args:
loggername (str): 日志记录器的名称。
save_level (str): 日志保存的级别。默认为"both"。file: 仅保存到文件channel: 仅保存到控制台。
log_path (str, optional): 日志文件保存路径。默认为None。
"""
if save_level not in self.SAVE_LEVELS:
raise ValueError(f"Invalid save level: {save_level}. Allowed values: {self.SAVE_LEVELS}")
self.logger = logging.getLogger(loggername)
self.logger.setLevel(logging.DEBUG)
# 设置时区为东八区
tz = pytz.timezone("Asia/Shanghai")
# 自定义时间格式化器,指定时区为东八区
class CSTFormatter(logging.Formatter):
"""
自定义时间格式化器,指定时区为东八区
"""
def converter(self, timestamp):
"""
自定义时间转换函数,加上时区信息
Args:
timestamp (int): 时间戳。
Returns:
tuple: 格式化后的时间元组。
"""
dt = datetime.utcfromtimestamp(timestamp)
dt = pytz.utc.localize(dt).astimezone(tz)
return dt.timetuple()
formatter = CSTFormatter(self.LOG_FORMAT)
log_name = None
if save_level == "both" or save_level == "file":
os.makedirs(log_path, exist_ok=True)
log_filename = f"out_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.log"
log_name = os.path.join(log_path, log_filename)
file_handler = logging.FileHandler(log_name, encoding="utf-8")
file_handler.setLevel(logging.DEBUG)
file_handler.setFormatter(formatter)
self.logger.addHandler(file_handler)
if save_level == "both" or save_level == "channel":
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.DEBUG)
console_handler.setFormatter(formatter)
self.logger.addHandler(console_handler)
if log_name is None:
self.logger.info(
f"Logger initialized. Log level: {save_level}. "
f"Log path ({log_path}) is unused according to the level."
)
else:
self.logger.info(f"Logger initialized. Log level: {save_level}. Log path: {log_name}")
# Adjusting the timezone offset
def get_logger(self):
"""
Get the logger object
"""
return self.logger
if __name__ == "__main__":
# Test the logger
logger = Logger("test_logger", save_level="channel").get_logger()
logger.info("the is the beginning")
logger.debug("the is the beginning")
logger.warning("the is the beginning")
logger.error("the is the beginning")

View File

@@ -0,0 +1,25 @@
#!/bin/env python3
# -*- coding: utf-8 -*-
# @author DDDivano
# encoding=utf-8 vi:ts=4:sw=4:expandtab:ft=python
"""
ServeTest
"""
TOKEN_LOGPROB = {
"model": "default",
"temperature": 0,
"top_p": 0,
"seed": 33,
"stream": True,
"logprobs": True,
"top_logprobs": 5,
"max_tokens": 10000,
}
TEMPLATES = {
"TOKEN_LOGPROB": TOKEN_LOGPROB,
# "ANOTHER_TEMPLATE": ANOTHER_TEMPLATE
}

View File

@@ -0,0 +1,57 @@
#!/bin/env python3
# -*- coding: utf-8 -*-
# @author DDDivano
# encoding=utf-8 vi:ts=4:sw=4:expandtab:ft=python
import requests
from core import TEMPLATES, base_logger
def build_request_payload(template_name: str, case_data: dict) -> dict:
"""
基于模板构造请求 payload按优先级依次合并
template < payload 参数 < case_data后者会覆盖前者的同名字段。
:param template_name: 模板变量名,例如 "TOKEN_LOGPROB"
:return: 构造后的完整请求 payload dict
"""
template = TEMPLATES[template_name]
print(template)
final_payload = template.copy()
final_payload.update(case_data)
return final_payload
def send_request(url, payload, timeout=600, stream=False):
"""
向指定URL发送POST请求并返回响应结果。
Args:
url (str): 请求的目标URL。
payload (dict): 请求的负载数据,应该是一个字典类型。
timeout (int, optional): 请求的超时时间默认为600秒。
stream (bool, optional): 是否以流的方式下载响应内容默认为False。
Returns:
response: 请求的响应结果如果请求失败则返回None。
Raises:
None
"""
headers = {
"Content-Type": "application/json",
}
base_logger.info("🔄 正在请求模型接口...")
try:
res = requests.post(url, headers=headers, json=payload, stream=stream, timeout=timeout)
base_logger.info("🟢 接收响应中...\n")
return res
except requests.exceptions.Timeout:
base_logger.error(f"❌ 请求超时(超过 {timeout} 秒)")
return None
except requests.exceptions.RequestException as e:
base_logger.error(f"❌ 请求失败:{e}")
return None

48
test/ce/server/demo.py Normal file
View File

@@ -0,0 +1,48 @@
#!/bin/env python3
# -*- coding: utf-8 -*-
# @author DDDivano
# encoding=utf-8 vi:ts=4:sw=4:expandtab:ft=python
from core import TEMPLATE, URL, build_request_payload, send_request
def demo():
data = {
"stream": False,
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "牛顿的三大运动定律是什么?"},
],
"max_tokens": 3,
}
payload = build_request_payload(TEMPLATE, data)
req = send_request(URL, payload)
print(req.json())
req = req.json()
assert req["usage"]["prompt_tokens"] == 22
assert req["usage"]["total_tokens"] == 25
assert req["usage"]["completion_tokens"] == 3
def test_demo():
data = {
"stream": False,
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "牛顿的三大运动定律是什么?"},
],
"max_tokens": 3,
}
payload = build_request_payload(TEMPLATE, data)
req = send_request(URL, payload)
print(req.json())
req = req.json()
assert req["usage"]["prompt_tokens"] == 22
assert req["usage"]["total_tokens"] == 25
assert req["usage"]["completion_tokens"] == 5
if __name__ == "__main__":
demo()