Files
FastDeploy/tests/e2e/test_ernie_21b_mtp.py
Zhang Yulong 94be5ebdd1 [CI] Add CI case for MTP accept ratio (#5570)
* Implement test for MTP accept ratio

Add test for MTP accept ratio with assertions on results and metrics.

* Update test_ernie_21b_mtp.py

* Refactor test_mtp_accept_ratio for baseline comparison

Refactor test_mtp_accept_ratio to compare results against baseline file and metrics.

* Fix formatting issues in test_ernie_21b_mtp.py

---------

Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
2025-12-17 21:35:02 +08:00

368 lines
13 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import shutil
import signal
import subprocess
import sys
import time
import pytest
import requests
from utils.serving_utils import (
FD_API_PORT,
FD_CACHE_QUEUE_PORT,
FD_ENGINE_QUEUE_PORT,
FD_METRICS_PORT,
clean,
is_port_open,
)
@pytest.fixture(scope="session", autouse=True)
def setup_and_run_server():
    """
    Pytest fixture that runs once per test session:
    - Cleans ports (via ``clean()``) and removes the ``log`` directory before tests
    - Starts the API server as a subprocess (tensor-parallel 2, wint4 quantization,
      MTP speculative decoding with one speculative token), logging to ``server.log``
    - Waits for the server port to open (up to 300 seconds), failing fast if the
      server process exits before the port opens
    - Tears down the server's process group after all tests finish
    """
    print("Pre-test port cleanup...")
    clean()
    print("log dir clean ")
    # Remove logs left by a previous run so this session's logs start fresh.
    if os.path.exists("log") and os.path.isdir("log"):
        shutil.rmtree("log")

    base_path = os.getenv("MODEL_PATH")
    if base_path:
        model_path = os.path.join(base_path, "ernie-4_5-21b-a3b-bf16-paddle")
    else:
        model_path = "./ernie-4_5-21b-a3b-bf16-paddle"
    # The MTP draft model is expected to live in a ``mtp`` sub-directory of the main model.
    mtp_model_path = os.path.join(model_path, "mtp")
    speculative_config = {"method": "mtp", "num_speculative_tokens": 1, "model": mtp_model_path}
    log_path = "server.log"
    cmd = [
        sys.executable,
        "-m",
        "fastdeploy.entrypoints.openai.api_server",
        "--model",
        model_path,
        "--port",
        str(FD_API_PORT),
        "--tensor-parallel-size",
        "2",
        "--engine-worker-queue-port",
        str(FD_ENGINE_QUEUE_PORT),
        "--metrics-port",
        str(FD_METRICS_PORT),
        "--cache-queue-port",
        str(FD_CACHE_QUEUE_PORT),
        "--max-model-len",
        "32768",
        "--max-num-seqs",
        "128",
        "--quantization",
        "wint4",
        "--speculative-config",
        json.dumps(speculative_config),
        "--graph-optimization-config",
        '{"use_cudagraph":true, "use_unique_memory_pool":true, "draft_model_use_cudagraph":true}',
    ]

    # Start the server in a new session/process group so that teardown can kill
    # the server together with any worker processes it spawns via os.killpg.
    with open(log_path, "w") as logfile:
        process = subprocess.Popen(
            cmd,
            stdout=logfile,
            stderr=subprocess.STDOUT,
            start_new_session=True,  # Enables killing full group via os.killpg
        )

    # Wait up to 300 seconds for API server to be ready.
    server_ready = False
    for _ in range(300):
        if is_port_open("127.0.0.1", FD_API_PORT):
            print(f"Server is up on port {FD_API_PORT}")
            server_ready = True
            break
        if process.poll() is not None:
            # The server died during startup; waiting out the full timeout is pointless.
            print(f"[ERROR] API server exited early with code {process.returncode}. See {log_path}.")
            break
        time.sleep(1)

    if not server_ready:
        if process.poll() is None:
            print("[TIMEOUT] API server failed to start in 5 minutes. Cleaning up...")
        try:
            os.killpg(process.pid, signal.SIGTERM)
        except Exception as e:
            print(f"Failed to kill process group: {e}")
        # Run port cleanup even if the process group was already gone.
        try:
            clean()
        except Exception as e:
            print(f"Failed to clean up ports: {e}")
        raise RuntimeError(f"API server did not start on port {FD_API_PORT}")

    yield  # Run tests

    print("\n===== Post-test server cleanup... =====")
    try:
        os.killpg(process.pid, signal.SIGTERM)
        clean()
        print(f"server (pid={process.pid}) terminated")
    except Exception as e:
        print(f"Failed to terminate API server: {e}")
@pytest.fixture(scope="session")
def api_url(request):
    """
    Provide the chat-completions endpoint URL of the locally started API server.
    """
    base = f"http://0.0.0.0:{FD_API_PORT}"
    return base + "/v1/chat/completions"
@pytest.fixture(scope="session")
def metrics_url(request):
    """
    Provide the URL of the server's metrics endpoint.
    """
    base = f"http://0.0.0.0:{FD_METRICS_PORT}"
    return base + "/metrics"
@pytest.fixture
def headers():
    """
    Provide the JSON content-type header used for HTTP requests.
    """
    return dict([("Content-Type", "application/json")])
def send_request(url, payload, timeout=60):
    """
    POST ``payload`` as JSON to ``url`` and return the response.

    Args:
        url: endpoint to post to.
        payload: JSON-serialisable request body.
        timeout: request timeout in seconds.

    Returns:
        The ``requests.Response`` on success, or ``None`` when the request times
        out or fails with any other ``requests`` error (the reason is printed).
    """
    json_headers = {"Content-Type": "application/json"}
    try:
        resp = requests.post(url, headers=json_headers, json=payload, timeout=timeout)
    except requests.exceptions.Timeout:
        # Timeout subclasses RequestException, so it must be handled first.
        print(f"❌ 请求超时(超过 {timeout} 秒)")
        return None
    except requests.exceptions.RequestException as e:
        print(f"❌ 请求失败:{e}")
        return None
    print("🟢 接收响应中...\n")
    return resp
def get_stream_chunks(response):
    """
    Parse a streaming (SSE-style) completion response into a list of chunks.

    Every non-empty line has an optional ``data: `` prefix stripped and is then
    decoded as JSON. Parsing stops at the ``[DONE]`` sentinel. Lines that are not
    valid JSON are reported and skipped.

    Args:
        response: the value returned by ``send_request``, either a
            ``requests.Response`` or ``None`` when the request failed.

    Returns:
        List[dict]: the decoded chunks. Returns an empty list when there is no
        response or the status code is not 200.
    """
    chunks = []
    if response is None:
        # send_request has already printed why the request failed.
        print("请求失败, 未收到响应")
        return chunks
    if response.status_code != 200:
        print(f"请求失败,状态码: {response.status_code}")
        print("返回内容:", response.text)
        return chunks

    for line in response.iter_lines(decode_unicode=True):
        if not line:
            continue
        if line.startswith("data: "):
            line = line[len("data: ") :]
        if line.strip() == "[DONE]":
            break
        try:
            chunks.append(json.loads(line))
        except ValueError as e:
            # json.JSONDecodeError is a ValueError subclass; skip malformed lines.
            print(f"解析失败: {e}, 行内容: {line}")
    return chunks
def test_chat_usage_stream(api_url):
    """
    Streaming chat completion: check the reported usage statistics.

    The answer must be non-empty, ``completion_tokens`` must lie within
    ``[min_tokens, max_tokens]`` from the request, and ``total_tokens`` must equal
    ``prompt_tokens + completion_tokens`` (read from the final chunk).
    """
    payload = {
        "model": "default",
        "temperature": 0,
        "top_p": 0,
        "seed": 33,
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "牛顿的三大运动定律是什么?"},
        ],
        "max_tokens": 50,
        "stream": True,
        "stream_options": {"include_usage": True, "continuous_usage_stats": True},
        "metadata": {"min_tokens": 10},
    }
    chunks = get_stream_chunks(send_request(url=api_url, payload=payload))

    # The last chunk is read for usage only, so the answer comes from the others.
    content_parts = [chunk["choices"][0]["delta"]["content"] for chunk in chunks[:-1]]
    result = "".join(content_parts)
    print("Prefill Response:", result)
    assert result != "", "结果为空"

    usage = chunks[-1]["usage"]
    completion_tokens = usage["completion_tokens"]
    expected_total = completion_tokens + usage["prompt_tokens"]
    assert payload["max_tokens"] >= completion_tokens, "completion_tokens大于max_tokens"
    assert payload["metadata"]["min_tokens"] <= completion_tokens, "completion_tokens小于min_tokens"
    assert usage["total_tokens"] == expected_total, "total_tokens不等于prompt_tokens + completion_tokens"
def test_chat_usage_non_stream(api_url):
    """
    Non-streaming chat completion: check the reported usage statistics.

    The answer must be non-empty, ``completion_tokens`` must lie within
    ``[min_tokens, max_tokens]`` from the request, and ``total_tokens`` must equal
    ``prompt_tokens + completion_tokens``.
    """
    payload = {
        "model": "default",
        "temperature": 0,
        "top_p": 0,
        "seed": 33,
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "牛顿的三大运动定律是什么?"},
        ],
        "max_tokens": 50,
        "stream": False,
        "metadata": {"min_tokens": 10},
    }
    response = send_request(url=api_url, payload=payload)
    # send_request returns None on timeout/connection errors; fail with a clear
    # message instead of an AttributeError on ``None.json()``.
    assert response is not None, "请求失败, 未收到响应"
    assert response.status_code == 200, f"请求失败,状态码: {response.status_code}, 返回内容: {response.text}"
    body = response.json()

    usage = body["usage"]
    result = body["choices"][0]["message"]["content"]
    assert result != "", "结果为空"
    total_tokens = usage["completion_tokens"] + usage["prompt_tokens"]
    assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens"
    assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens"
    assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens"
def test_non_chat_usage_stream(api_url):
    """
    Streaming (non-chat) completion: check the reported usage statistics.

    Uses the ``/completions`` endpoint derived from the chat URL. The answer must
    be non-empty, ``completion_tokens`` must lie within ``[min_tokens, max_tokens]``
    from the request, and ``total_tokens`` must equal
    ``prompt_tokens + completion_tokens`` (read from the final chunk).
    """
    payload = {
        "model": "default",
        "temperature": 0,
        "top_p": 0,
        "seed": 33,
        "prompt": "牛顿的三大运动定律是什么?",
        "max_tokens": 50,
        "stream": True,
        "stream_options": {"include_usage": True, "continuous_usage_stats": True},
        "metadata": {"min_tokens": 10},
    }
    completions_url = api_url.replace("chat/completions", "completions")
    chunks = get_stream_chunks(send_request(url=completions_url, payload=payload))

    # The last chunk is read for usage only, so the text comes from the others.
    text_parts = [chunk["choices"][0]["text"] for chunk in chunks[:-1]]
    result = "".join(text_parts)
    assert result != "", "结果为空"

    usage = chunks[-1]["usage"]
    completion_tokens = usage["completion_tokens"]
    expected_total = completion_tokens + usage["prompt_tokens"]
    assert payload["max_tokens"] >= completion_tokens, "completion_tokens大于max_tokens"
    assert payload["metadata"]["min_tokens"] <= completion_tokens, "completion_tokens小于min_tokens"
    assert usage["total_tokens"] == expected_total, "total_tokens不等于prompt_tokens + completion_tokens"
def test_non_chat_usage_non_stream(api_url):
    """
    Non-streaming (non-chat) completion: check the reported usage statistics.

    Uses the ``/completions`` endpoint derived from the chat URL. The answer must
    be non-empty, ``completion_tokens`` must lie within ``[min_tokens, max_tokens]``
    from the request, and ``total_tokens`` must equal
    ``prompt_tokens + completion_tokens``.
    """
    payload = {
        "model": "default",
        "temperature": 0,
        "top_p": 0,
        "seed": 33,
        "prompt": "牛顿的三大运动定律是什么?",
        "max_tokens": 50,
        "stream": False,
        "metadata": {"min_tokens": 10},
    }
    completions_url = api_url.replace("chat/completions", "completions")
    response = send_request(url=completions_url, payload=payload)
    # send_request returns None on timeout/connection errors; fail with a clear
    # message instead of an AttributeError on ``None.json()``.
    assert response is not None, "请求失败, 未收到响应"
    assert response.status_code == 200, f"请求失败,状态码: {response.status_code}, 返回内容: {response.text}"
    body = response.json()

    usage = body["usage"]
    result = body["choices"][0]["text"]
    assert result != "", "结果为空"
    total_tokens = usage["completion_tokens"] + usage["prompt_tokens"]
    assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens"
    assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens"
    assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens"
def _mtp_answer_and_metrics(chunks):
    """
    Split a streamed chat completion into its answer text and MTP metrics.

    The answer is joined from ``delta.content`` of every chunk except the last,
    because the last chunk is treated as the usage chunk. The speculative-decoding
    metrics are read from ``chunks[-2]``, presumably the last chunk that carries
    choices.

    Returns:
        tuple: ``(answer_text, speculate_metrics)``.
    """
    answer = "".join(chunk["choices"][0]["delta"]["content"] for chunk in chunks[:-1])
    metrics = chunks[-2]["choices"][0]["speculate_metrics"]
    return answer, metrics


def test_mtp_accept_ratio(api_url):
    """
    Test the MTP (speculative decoding) accept ratio against a baseline.

    The same greedy request (temperature=0, top_p=0, fixed seed) is streamed three
    times. The first two are best-effort warm-up requests whose output is only
    printed. For the third, measured request:
    - its answer text must equal ``21b_mtp_accept_ratio_baseline.txt`` under
      ``MODEL_PATH`` (the current directory when unset);
    - its ``speculate_metrics`` must equal the hard-coded baseline exactly;
    - ``cached_tokens`` must equal ``prompt_tokens`` rounded down to a multiple of 64.
    """
    payload = {
        "model": "default",
        "messages": [
            {
                "role": "user",
                "content": "国外项目风险管理研究起步较早,理论体系成熟。早期研究集中于保险与金融领域,后逐步扩展至工程项目、"
                "公共管理等多领域。在理论层面COSO《企业风险管理——整合框架》和ISO31000标准为风险管理提供了系统性"
                "指导,强调风险识别、评估、应对与监控的全流程管理。风险识别方法包括故障树分析、事件树分析等;风险评估"
                "则广泛应用VaR模型、蒙特卡洛模拟等量化工具。应对策略涵盖规避、转移、减轻和接受等并衍生出风险共享、"
                "升级等复杂策略。此外,组织文化、管理层支持等因素对风险管理有效性影响显著。近年来,随着科技发展,"
                "人工智能、大数据等技术被引入风险管理,推动其向智能化、自动化方向发展。请介绍一下国外关于项目风险管理"
                "的文献研究综述300字以内",
            },
        ],
        "stream": True,
        "stream_options": {"include_usage": True, "continuous_usage_stats": True},
        "temperature": 0,
        "seed": 23,
        "top_p": 0,
    }
    print("fastdeploy answer is :")

    # Pre-initialised so the diagnostic prints below cannot raise NameError when
    # the best-effort warm-up fails part-way.
    result = None
    speculate_metrics = None
    try:
        # TODO: the first and second responses currently differ (later ones are
        # stable), so send one extra warm-up request for now.
        for _ in range(2):
            response = send_request(url=api_url, payload=payload)
            chunks = get_stream_chunks(response)
        for idx, chunk in enumerate(chunks):
            print(f"\nchunk[{idx}]:\n{json.dumps(chunk, ensure_ascii=False)}")
        result, speculate_metrics = _mtp_answer_and_metrics(chunks)
    except Exception as e:
        # Warm-up is best-effort: only the measured request below is asserted on.
        print(f"解析失败: {e}")
    print("\nresult:\n", result)

    # Fall back to the current directory when MODEL_PATH is unset or empty,
    # matching how setup_and_run_server resolves the model directory.
    base_path = os.getenv("MODEL_PATH") or "."
    baseline_path = os.path.join(base_path, "21b_mtp_accept_ratio_baseline.txt")
    with open(baseline_path, "r", encoding="utf-8") as f:
        baseline = f.read()
    baseline_ratio = {
        "accepted_tokens": 131,
        "rejected_tokens": 23,
        "accept_ratio": 0.4122137404580153,
        "average_accept_length": 1.7012987012987013,
        "accept_ratio_per_head": [0.7012987012987013],
    }

    response = send_request(url=api_url, payload=payload)
    assert response is not None, "请求失败, 未收到响应"
    chunks = get_stream_chunks(response)
    # At least one content chunk plus the trailing usage chunk is required.
    assert len(chunks) >= 2, f"流式返回chunk数量异常: {len(chunks)}"
    result_2, speculate_metrics_2 = _mtp_answer_and_metrics(chunks)
    print("chunks:", chunks[-2])
    print("speculate_metrics", speculate_metrics)
    print("baseline", baseline_ratio)
    print("speculate_metrics_2", speculate_metrics_2)
    assert result_2 == baseline, f"与baseline存在diff\nresult_2: {result_2}\n baseline: {baseline}"
    assert speculate_metrics_2 == baseline_ratio, (
        f"speculate_metrics存在diff\n" f"speculate_metrics_2: {speculate_metrics_2}\n " f"baseline: {baseline_ratio}"
    )
    assert speculate_metrics_2["accept_ratio"] > 0, "accept_ratio异常"

    usage = chunks[-1]["usage"]
    prompt_tokens = usage["prompt_tokens"]
    cached_tokens = usage["prompt_tokens_details"]["cached_tokens"]
    # Expect the prefix-cache hit to be prompt_tokens rounded down to a multiple of
    # 64 (presumably the KV-cache block size; TODO confirm against server config).
    assert cached_tokens == prompt_tokens // 64 * 64, "cached_tokens数量有问题"