# Mirror of https://github.com/PaddlePaddle/FastDeploy.git
# synced 2025-10-05 16:48:03 +08:00
#!/bin/env python3
# -*- coding: utf-8 -*-
# @author: DDDivano
# encoding=utf-8 vi:ts=4:sw=4:expandtab:ft=python
import asyncio
import json
import os
import time
from collections import Counter
from statistics import mean, median

import aiohttp
from tqdm import tqdm

# ============ Configuration ============
# Target endpoint; override with the URL environment variable.
API_URL = os.environ.get("URL", "http://localhost:8000/v1/chat/completions")
MAX_CONCURRENCY = 200  # maximum number of concurrent worker coroutines
TOTAL_REQUESTS = 300000  # total number of requests to send
TIMEOUT = 1800  # per-request timeout (seconds)
DATA_FILE = "math_15k.jsonl"  # JSONL file providing request prompts


# ============ Data loading ============
async def load_data():
    """Read prompts from DATA_FILE, one per line.

    Each line is parsed as JSON when possible; non-JSON lines fall back to
    their raw (stripped) text. Raises ValueError when no prompts are found.
    """
    prompts = []
    with open(DATA_FILE, "r", encoding="utf-8") as fh:
        for raw in fh:
            try:
                record = json.loads(raw)
            except json.JSONDecodeError:
                # Not JSON: use the raw line text as the prompt.
                prompts.append(raw.strip())
                continue
            # RL requirement: prefer the first "src" entry when present.
            if "src" in record:
                prompts.append(record["src"][0])
            else:
                prompts.append(record.get("content", raw.strip()))
    if not prompts:
        raise ValueError(f"{DATA_FILE} 为空或格式不正确")
    return prompts


# ============ Request sending ============
async def send_request(session, payload):
    """POST *payload* to API_URL and time the round trip.

    Returns a tuple ``(ok, latency_seconds, status, error_detail)`` where
    ``error_detail`` is None on HTTP 200, the response body on a non-200
    status, or an exception description string when the request itself fails.
    """
    t0 = time.perf_counter()
    try:
        async with session.post(API_URL, json=payload) as resp:
            # Drain the body; prefer JSON, fall back to plain text.
            try:
                body = await resp.json()
            except Exception:
                body = await resp.text()
            elapsed = time.perf_counter() - t0
            ok = resp.status == 200
            return ok, elapsed, resp.status, None if ok else body
    except Exception as exc:
        elapsed = time.perf_counter() - t0
        return False, elapsed, None, f"{type(exc).__name__}: {exc}"


# ============ Worker ============
async def worker(name, session, prompts, counter, latencies, pbar, queue):
    """Consume request indices from *queue* until a None sentinel arrives.

    For each index, send one chat-completion request and record the outcome
    in *counter* (and *latencies* on success). Updates *pbar* per request.
    """
    n_prompts = len(prompts)
    while True:
        idx = await queue.get()
        if idx is None:
            # Poison pill: acknowledge it and exit the loop.
            queue.task_done()
            break

        payload = {
            "model": "eb",
            "messages": [{"role": "user", "content": prompts[idx % n_prompts]}],
            "max_prompt_len": 2048,
            "max_dec_len": 1024,
            "min_dec_len": 32,
            "top_p": 1.0,
            "temperature": 1.0,
            "repetition_penalty": 1.0,
            "rollout_quant_type": "weight_only_int8",
            "disable_chat_template": True,
        }

        ok, latency, status, error = await send_request(session, payload)
        if ok:
            counter["success"] += 1
            latencies.append(latency)
        else:
            counter["fail"] += 1
            counter[f"error_{error or 'client'}"] += 1

        pbar.update(1)
        queue.task_done()


# ============ Main flow ============
async def run_load_test():
    """Drive TOTAL_REQUESTS requests through MAX_CONCURRENCY worker tasks."""
    prompts = await load_data()
    # Bounded queue keeps the producer from racing far ahead of the workers,
    # which caps memory usage.
    queue = asyncio.Queue(maxsize=MAX_CONCURRENCY * 5)
    counter = Counter()
    latencies = []

    connector = aiohttp.TCPConnector(limit=MAX_CONCURRENCY * 2)  # cap TCP connections
    timeout = aiohttp.ClientTimeout(total=TIMEOUT)

    async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session:
        with tqdm(total=TOTAL_REQUESTS, desc="压测进度") as pbar:
            # Spawn the worker pool.
            workers = []
            for i in range(MAX_CONCURRENCY):
                task = asyncio.create_task(
                    worker(f"W{i}", session, prompts, counter, latencies, pbar, queue)
                )
                workers.append(task)

            # Produce indices while the workers consume them.
            for i in range(TOTAL_REQUESTS):
                await queue.put(i)

            # One sentinel per worker so each exits its loop.
            for _ in workers:
                await queue.put(None)

            await queue.join()
            await asyncio.gather(*workers)

    generate_report(counter, latencies)


# ============ Report output ============
def generate_report(counter, latencies):
    """Print the load-test summary: totals, per-error tallies, latency stats.

    *counter* holds "success"/"fail" counts plus "error_*" breakdown keys;
    *latencies* is the list of per-request durations for successful calls.
    """
    succeeded = counter["success"]
    failed = counter["fail"]
    report = [
        "\n====== 压测报告 ======",
        f"总请求数: {succeeded + failed}",
        f"成功数: {succeeded}",
        f"失败数: {failed}",
    ]
    # Per-error breakdown, in first-seen order.
    report.extend(
        f"{key}: {count}" for key, count in counter.items() if key.startswith("error_")
    )
    if latencies:
        report.append(f"平均延迟: {mean(latencies):.4f}s")
        report.append(f"中位延迟: {median(latencies):.4f}s")
        report.append(f"最快: {min(latencies):.4f}s")
        report.append(f"最慢: {max(latencies):.4f}s")
    report.append("=====================")
    print("\n".join(report))


if __name__ == "__main__":
    # Entry point: run the async load test to completion.
    asyncio.run(run_load_test())