# FastDeploy/tests/e2e/test_ernie_03b_pd_splitwise_scheduler.py
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Tests a splitwise (prefill/decode disaggregated) deployment that uses the
# splitwise_scheduler, with ENABLE_V1_KVCACHE_SCHEDULER set to 0.
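# In this deployment, a prefill instance (GPU 0) and a decode instance (GPU 1)
# each register with the splitwise scheduler through a local Redis server, and
# KV-cache data is transferred between them over the "ipc" cache-transfer
# protocol configured below.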

import json
import os
import shutil
import signal
import subprocess
import sys
import time

import pytest
import requests
from utils.serving_utils import (
    FD_API_PORT,
    FD_CACHE_QUEUE_PORT,
    FD_ENGINE_QUEUE_PORT,
    FD_METRICS_PORT,
    clean,
    is_port_open,
)
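
# NOTE (assumption): `clean(ports)` is expected to terminate any process still
# bound to the given ports, and `is_port_open(host, port)` to report whether a
# TCP connection to host:port succeeds; both come from utils.serving_utils.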

# Read ports from environment variables; use default values if not set
FD_CONNECTOR_PORT = int(os.getenv("FD_CONNECTOR_PORT", 8433))
FD_REDIS_PORT = int(os.getenv("FD_REDIS_PORT", 8533))

# List of ports to clean before and after tests
PORTS_TO_CLEAN = [
    FD_API_PORT,
    FD_ENGINE_QUEUE_PORT,
    FD_METRICS_PORT,
    FD_CACHE_QUEUE_PORT,
    FD_CONNECTOR_PORT,
    FD_API_PORT + 1,
    FD_ENGINE_QUEUE_PORT + 1,
    FD_METRICS_PORT + 1,
    FD_CACHE_QUEUE_PORT + 1,
    FD_CONNECTOR_PORT + 1,
    FD_REDIS_PORT,
]
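
# The base ports serve the prefill instance; the `port + 1` variants serve the
# decode instance started below. FD_REDIS_PORT is shared by both instances via
# the splitwise scheduler.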
@pytest.fixture(scope="session", autouse=True)
def setup_and_run_server():
"""
Pytest fixture that runs once per test session:
- Cleans ports before tests
- Starts the API server as a subprocess
- Waits for server port to open (up to 30 seconds)
- Tears down server after all tests finish
"""
print("Pre-test port cleanup...")
clean(PORTS_TO_CLEAN)
print("log dir clean ")
if os.path.exists("log_redis") and os.path.isdir("log_redis"):
shutil.rmtree("log_redis")
if os.path.exists("log_prefill") and os.path.isdir("log_prefill"):
shutil.rmtree("log_prefill")
if os.path.exists("log_decode") and os.path.isdir("log_decode"):
shutil.rmtree("log_decode")
base_path = os.getenv("MODEL_PATH")
if base_path:
model_path = os.path.join(base_path, "ERNIE-4.5-0.3B-Paddle")
else:
model_path = "baidu/ERNIE-4.5-0.3B-Paddle"
print(f"model_path: {model_path}")
# redis-server
print("start redis...")
env_copy = os.environ.copy()
log_path = "router.log"
cmd = [
"redis-server",
"--port",
str(FD_REDIS_PORT),
"--daemonize",
"yes",
]
with open(log_path, "w") as logfile:
process_redis = subprocess.Popen(
cmd,
stdout=logfile,
stderr=subprocess.STDOUT,
start_new_session=True, # Enables killing full group via os.killpg
env=env_copy,
)
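
    # NOTE (assumption): with `--daemonize yes`, redis forks into the background
    # and this Popen handle refers only to the short-lived launcher, so the
    # os.killpg() in the teardown below may be a no-op for redis itself; the
    # final clean(PORTS_TO_CLEAN) is what reliably frees FD_REDIS_PORT.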

    # Prefill instance
    print("start prefill...")
    env_prefill = os.environ.copy()
    env_prefill["CUDA_VISIBLE_DEVICES"] = "0"
    env_prefill["ENABLE_V1_KVCACHE_SCHEDULER"] = "0"
    env_prefill["FD_LOG_DIR"] = "log_prefill"
    prefill_log_path = "server_prefill.log"
    prefill_cmd = [
        sys.executable,
        "-m",
        "fastdeploy.entrypoints.openai.api_server",
        "--model",
        model_path,
        "--port",
        str(FD_API_PORT),
        "--tensor-parallel-size",
        "1",
        "--engine-worker-queue-port",
        str(FD_ENGINE_QUEUE_PORT),
        "--metrics-port",
        str(FD_METRICS_PORT),
        "--cache-queue-port",
        str(FD_CACHE_QUEUE_PORT),
        "--max-model-len",
        "8192",
        "--max-num-seqs",
        "20",
        "--quantization",
        "wint8",
        "--splitwise-role",
        "prefill",
        "--cache-transfer-protocol",
        "ipc",
        "--pd-comm-port",
        str(FD_CONNECTOR_PORT),
        "--scheduler-name",
        "splitwise",
        "--scheduler-host",
        "127.0.0.1",
        "--scheduler-port",
        str(FD_REDIS_PORT),
    ]
    # Start subprocess in new process group
    with open(prefill_log_path, "w") as logfile:
        process_prefill = subprocess.Popen(
            prefill_cmd,
            stdout=logfile,
            stderr=subprocess.STDOUT,
            start_new_session=True,  # Enables killing full group via os.killpg
            env=env_prefill,
        )
    time.sleep(1)  # Give the prefill instance a brief head start before launching decode

    # Decode instance
    print("start decode...")
    env_decode = os.environ.copy()
    env_decode["CUDA_VISIBLE_DEVICES"] = "1"
    env_decode["ENABLE_V1_KVCACHE_SCHEDULER"] = "0"
    env_decode["FD_LOG_DIR"] = "log_decode"
    decode_log_path = "server_decode.log"
    decode_cmd = [
        sys.executable,
        "-m",
        "fastdeploy.entrypoints.openai.api_server",
        "--model",
        model_path,
        "--port",
        str(FD_API_PORT + 1),
        "--tensor-parallel-size",
        "1",
        "--engine-worker-queue-port",
        str(FD_ENGINE_QUEUE_PORT + 1),
        "--metrics-port",
        str(FD_METRICS_PORT + 1),
        "--cache-queue-port",
        str(FD_CACHE_QUEUE_PORT + 1),
        "--max-model-len",
        "8192",
        "--max-num-seqs",
        "20",
        "--quantization",
        "wint8",
        "--splitwise-role",
        "decode",
        "--cache-transfer-protocol",
        "ipc",
        "--pd-comm-port",
        str(FD_CONNECTOR_PORT + 1),
        "--scheduler-name",
        "splitwise",
        "--scheduler-host",
        "127.0.0.1",
        "--scheduler-port",
        str(FD_REDIS_PORT),
    ]
    # Start subprocess in new process group
    with open(decode_log_path, "w") as logfile:
        process_decode = subprocess.Popen(
            decode_cmd,
            stdout=logfile,
            stderr=subprocess.STDOUT,
            start_new_session=True,  # Enables killing full group via os.killpg
            env=env_decode,
        )

    # Wait up to 300 seconds for both API servers to be ready
    for _ in range(60):
        if is_port_open("127.0.0.1", FD_API_PORT) and is_port_open("127.0.0.1", FD_API_PORT + 1):
            print(f"Prefill server is up on port {FD_API_PORT}")
            print(f"Decode server is up on port {FD_API_PORT + 1}")
            break
        time.sleep(5)
    else:
        print("[TIMEOUT] API servers failed to start in 5 minutes. Cleaning up...")
        try:
            os.killpg(process_prefill.pid, signal.SIGTERM)
            os.killpg(process_decode.pid, signal.SIGTERM)
            clean(PORTS_TO_CLEAN)
        except Exception as e:
            print(f"Failed to kill process group: {e}")
        raise RuntimeError(f"API servers did not start on ports {FD_API_PORT} and {FD_API_PORT + 1}")

    yield  # Run tests

    print("\n===== Post-test server cleanup... =====")
    try:
        os.killpg(process_redis.pid, signal.SIGTERM)
        os.killpg(process_prefill.pid, signal.SIGTERM)
        os.killpg(process_decode.pid, signal.SIGTERM)
        clean(PORTS_TO_CLEAN)
        print(f"Prefill server (pid={process_prefill.pid}) terminated")
        print(f"Decode server (pid={process_decode.pid}) terminated")
    except Exception as e:
        print(f"Failed to terminate API servers: {e}")
@pytest.fixture(scope="session")
def api_url(request):
"""
Returns the API endpoint URL for chat completions.
"""
return f"http://0.0.0.0:{FD_API_PORT}/v1/chat/completions", f"http://0.0.0.0:{FD_API_PORT+1}/v1/chat/completions"
@pytest.fixture(scope="session")
def metrics_url(request):
"""
Returns the metrics endpoint URL.
"""
return f"http://0.0.0.0:{FD_METRICS_PORT}/metrics"
@pytest.fixture
def headers():
"""
Returns common HTTP request headers.
"""
return {"Content-Type": "application/json"}


def test_metrics_config(metrics_url):
    """The /config-info endpoint (derived from the metrics URL) should respond with HTTP 200."""
    timeout = 600
    url = metrics_url.replace("metrics", "config-info")
    res = requests.get(url, timeout=timeout)
    assert res.status_code == 200


def send_request(url, payload, timeout=60):
    """
    Send a request to the given URL and return the response.
    """
    headers = {
        "Content-Type": "application/json",
    }
    try:
        res = requests.post(url, headers=headers, json=payload, timeout=timeout)
        print("🟢 Receiving response...\n")
        return res
    except requests.exceptions.Timeout:
        print(f"❌ Request timed out (exceeded {timeout} seconds)")
        return None
    except requests.exceptions.RequestException as e:
        print(f"❌ Request failed: {e}")
        return None


def get_stream_chunks(response):
    """Parse a streaming response into a list of chunk dicts."""
    chunks = []
    if response.status_code == 200:
        for line in response.iter_lines(decode_unicode=True):
            if line:
                if line.startswith("data: "):
                    line = line[len("data: ") :]
                if line.strip() == "[DONE]":
                    break
                try:
                    chunk = json.loads(line)
                    chunks.append(chunk)
                except Exception as e:
                    print(f"Parse failed: {e}, line content: {line}")
    else:
        print(f"Request failed, status code: {response.status_code}")
        print("Response body:", response.text)
    return chunks
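
# For reference (shape assumed from the OpenAI-compatible streaming API parsed
# above): each streamed line looks like
#   data: {"choices": [{"delta": {"content": "..."}}], ..., "usage": {...}}
# and the stream is terminated by a literal "data: [DONE]" sentinel.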


def test_chat_usage_stream(api_url):
    """Test usage accounting for streaming chat completions."""
    payload = {
        "model": "default",
        "temperature": 0,
        "top_p": 0,
        "seed": 33,
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "牛顿的三大运动定律是什么?"},
        ],
        "max_tokens": 50,
        "stream": True,
        "stream_options": {"include_usage": True, "continuous_usage_stats": True},
        "metadata": {"min_tokens": 10},
    }
    p_url, d_url = api_url
    response = send_request(url=p_url, payload=payload)
    chunks = get_stream_chunks(response)
    result = "".join([x["choices"][0]["delta"]["content"] for x in chunks[:-1]])
    print("Decode Response:", result)
    assert result != "", "Result is empty"
    usage = chunks[-1]["usage"]
    total_tokens = usage["completion_tokens"] + usage["prompt_tokens"]
    assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens exceeds max_tokens"
    assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens is less than min_tokens"
    assert usage["total_tokens"] == total_tokens, "total_tokens != prompt_tokens + completion_tokens"


def test_chat_usage_non_stream(api_url):
    """Test usage accounting for non-streaming chat completions."""
    payload = {
        "model": "default",
        "temperature": 0,
        "top_p": 0,
        "seed": 33,
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "牛顿的三大运动定律是什么?"},
        ],
        "max_tokens": 50,
        "stream": False,
        "metadata": {"min_tokens": 10},
    }
    p_url, d_url = api_url
    response = send_request(url=p_url, payload=payload).json()
    usage = response["usage"]
    result = response["choices"][0]["message"]["content"]
    assert result != "", "Result is empty"
    total_tokens = usage["completion_tokens"] + usage["prompt_tokens"]
    assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens exceeds max_tokens"
    assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens is less than min_tokens"
    assert usage["total_tokens"] == total_tokens, "total_tokens != prompt_tokens + completion_tokens"


def test_non_chat_usage_stream(api_url):
    """Test usage accounting for streaming (non-chat) completions."""
    payload = {
        "model": "default",
        "temperature": 0,
        "top_p": 0,
        "seed": 33,
        "prompt": "牛顿的三大运动定律是什么?",
        "max_tokens": 50,
        "stream": True,
        "stream_options": {"include_usage": True, "continuous_usage_stats": True},
        "metadata": {"min_tokens": 10},
    }
    p_url, d_url = api_url
    p_url = p_url.replace("chat/completions", "completions")
    response = send_request(url=p_url, payload=payload)
    chunks = get_stream_chunks(response)
    result = "".join([x["choices"][0]["text"] for x in chunks[:-1]])
    # print("Decode Response:", result)
    assert result != "", "Result is empty"
    usage = chunks[-1]["usage"]
    total_tokens = usage["completion_tokens"] + usage["prompt_tokens"]
    assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens exceeds max_tokens"
    assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens is less than min_tokens"
    assert usage["total_tokens"] == total_tokens, "total_tokens != prompt_tokens + completion_tokens"


def test_non_chat_usage_non_stream(api_url):
    """Test usage accounting for non-streaming (non-chat) completions."""
    payload = {
        "model": "default",
        "temperature": 0,
        "top_p": 0,
        "seed": 33,
        "prompt": "牛顿的三大运动定律是什么?",
        "max_tokens": 50,
        "stream": False,
        "metadata": {"min_tokens": 10},
    }
    p_url, d_url = api_url
    p_url = p_url.replace("chat/completions", "completions")
    response = send_request(url=p_url, payload=payload).json()
    usage = response["usage"]
    result = response["choices"][0]["text"]
    # print("Decode Response:", result)
    assert result != "", "Result is empty"
    total_tokens = usage["completion_tokens"] + usage["prompt_tokens"]
    assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens exceeds max_tokens"
    assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens is less than min_tokens"
    assert usage["total_tokens"] == total_tokens, "total_tokens != prompt_tokens + completion_tokens"
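

# Typical invocation (assumed), from the FastDeploy/tests/e2e directory with two
# GPUs and a local `redis-server` binary available:
#   MODEL_PATH=/path/to/models python -m pytest -sv test_ernie_03b_pd_splitwise_scheduler.py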