[PD Disaggregation] support different tp_size for prefill and decode (#5296)

* up

* up

* up

* fix
Juncai authored 2025-12-01 17:50:20 +08:00, committed by GitHub
parent 54119cf07e
commit 0925d44f18
13 changed files with 584 additions and 36 deletions

View File

@@ -55,13 +55,13 @@ def parse_args():
default="mixed",
help="splitwise role, can be decode, prefill or mixed",
)
parser.add_argument("--rank", type=int, default=0, help="current rank")
parser.add_argument("--rank", type=int, default=0, help="local tp rank id")
parser.add_argument("--device_id", type=int, default=0, help="device id")
parser.add_argument("--num_layers", type=int, default=1, help="model num layers")
parser.add_argument("--key_cache_shape", type=str, default="", help="key cache shape")
parser.add_argument("--value_cache_shape", type=str, default="", help="value cache shape")
parser.add_argument("--rdma_port", type=str, default="", help="rdma port")
parser.add_argument("--mp_num", type=int, default=1, help="number of model parallel")
parser.add_argument("--mp_num", type=int, default=1, help="number of model parallel ranks, i.e. tp_size")
parser.add_argument("--engine_pid", type=str, default=None, help="engine pid")
parser.add_argument(
"--protocol",
@@ -208,6 +208,8 @@ class CacheMessager:
max_block_num,
block_bytes,
rdma_port,
nranks,
rank,
)
self.gpu_id = gpu_id
@@ -507,6 +509,8 @@ class CacheMessagerV1:
max_block_num,
block_bytes,
rdma_port,
nranks,
rank,
)
self.gpu_id = gpu_id
@@ -595,6 +599,7 @@ class CacheMessagerV1:
block_id_end = prefilled_token_num // self.block_size # [block_id_start, block_id_end)
block_start_end_list.append((block_id_start, block_id_end))
current_prefilled_token_num_list.append(prefilled_token_num)
while True: # from layer0 to last layer
sended_layer_idx = self.idx_cache_task_dict[batch_engine_signals[0][0]]["sended_layer_id"]
start_layer_idx = sended_layer_idx + 1
@@ -633,13 +638,27 @@ class CacheMessagerV1:
current_transfer_protocol = task["transfer_protocol"]
if task["transfer_protocol"] == "rdma":
target_ip = task["ip"]
target_id = int(task["rdma_ports"][self.rank])
# Default decode_tp_size to prefill tp_size (self.nranks) if not specified
decode_tp_size = task.get("decode_tp_size", self.nranks)
if len(task["rdma_ports"]) == self.nranks:
target_id = int(task["rdma_ports"][self.rank])
elif len(task["rdma_ports"]) == 1:
target_id = task["rdma_ports"][0]
else:
task["status"] = "the tp_size of prefill and decode is mismatched"
continue
if "error" in task["status"]:
continue
# TODO: use is_connected to check if the connection is still alive
logger.debug(f"rdma, start connect decode, {target_ip}:{target_id}")
status = self.messager[current_transfer_protocol].connect(target_ip, target_id)
logger.debug(
f"rdma, start connect decode, {target_ip}:{target_id}, "
f"prefill_tp_size:{self.nranks}, decode_tp_size:{decode_tp_size}"
)
status = self.messager[current_transfer_protocol].connect(
target_ip, target_id, decode_tp_size
)
if status:
logger.info(f"connect to {target_ip}:{target_id} success")
else:
@@ -762,12 +781,22 @@ class CacheMessagerV1:
self.engine_worker_queue.connect_task_barrier.wait()
logger.info(f"_handle_connect_task recv task: {task}")
task_id = task["task_id"]
ip, rdma_port = task["ip"], task["rdma_ports"][self.rank]
status = self.messager["rdma"].connect(ip, rdma_port)
if not status:
ip = task["ip"]
# Default decode_tp_size to self.nranks (number of ranks) if not specified in the task.
decode_tp_size = task.get("decode_tp_size", self.nranks)
rdma_ports = task["rdma_ports"]
rdma_ports_len = len(rdma_ports)
if not (rdma_ports_len == 1 or rdma_ports_len == self.nranks):
# TODO: support other cases
logger.error(f"rdma_ports length should be 1 or equal to mp_num, but got {rdma_ports_len}")
response = {"task_id": task_id, "success": False}
else:
response = {"task_id": task_id, "success": True}
port = rdma_ports[0] if rdma_ports_len == 1 else rdma_ports[self.rank]
status = self.messager["rdma"].connect(ip, port, decode_tp_size)
if not status:
response = {"task_id": task_id, "success": False}
else:
response = {"task_id": task_id, "success": True}
self.engine_worker_queue.connect_task_response_barrier.wait()
self.engine_worker_queue.put_connect_rdma_task_response(response)
except Exception as e:
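For reference, a minimal sketch (hypothetical helper, not FastDeploy code) of the port-resolution rule the hunks above follow: a connect task's rdma_ports list must hold either exactly one entry (decode running with tp_size 1) or exactly nranks entries (matching tp sizes), and decode_tp_size falls back to the prefill tp size when the task omits it.

def resolve_decode_target(task: dict, rank: int, nranks: int):
    """Resolve (ip, port, decode_tp_size) for the local prefill rank."""
    # Default decode_tp_size to the prefill tp size (nranks) when not specified.
    decode_tp_size = task.get("decode_tp_size", nranks)
    rdma_ports = task["rdma_ports"]
    if len(rdma_ports) == nranks:
        port = int(rdma_ports[rank])   # rank i talks to decode rank i
    elif len(rdma_ports) == 1:
        port = int(rdma_ports[0])      # every prefill rank targets the single decode rank
    else:
        raise ValueError("the tp_size of prefill and decode is mismatched")
    return task["ip"], port, decode_tp_size

# prefill tp_size=2, decode tp_size=1: both prefill ranks resolve the same port
# resolve_decode_target({"ip": "10.0.0.2", "rdma_ports": [8623], "decode_tp_size": 1}, rank=1, nranks=2)
# -> ("10.0.0.2", 8623, 1)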

View File

@@ -142,6 +142,7 @@ struct Connection {
int wc_target_count;
// Configuration
int decode_tp_size;
int layer_number;
int block_number;
int block_byte_size;

View File

@@ -24,11 +24,15 @@ class RDMACommunicator {
std::vector<int64_t> local_key_cache,
std::vector<int64_t> local_value_cache,
int block_number,
int block_bytes);
int block_bytes,
int prefill_tp_size,
int prefill_tp_idx);
~RDMACommunicator();
// Connection management
int connect(const std::string& dst_ip, const std::string& dst_port);
int connect(const std::string& dst_ip,
const std::string& dst_port,
int dest_tp_size);
bool is_connected(const std::string& dst_ip, const std::string& dst_port);
// Core functionality
@@ -120,6 +124,8 @@ class RDMACommunicator {
int block_number; // Number of blocks
int block_size_byte; // Size of each block in bytes
int layer_number; // Number of layers
int prefill_tp_size; // tensor parallelism size for prefill
int prefill_tp_idx; // tensor parallelism index for prefill
std::vector<std::vector<void*>>
local_cache_key_ptr_per_layer; // Per-layer key pointers

View File

@@ -41,7 +41,7 @@
* @param local_key_cache Vector of local key cache pointers
* @param local_value_cache Vector of local value cache pointers
* @param block_number Number of blocks in cache
* @param block_bytes Size of each block in bytes
* @param block_bytes Size of each block in bytes on each tp rank
*
* @throws std::runtime_error If initialization fails
*/
@@ -51,7 +51,9 @@ RDMACommunicator::RDMACommunicator(std::string& role,
std::vector<int64_t> local_key_cache,
std::vector<int64_t> local_value_cache,
int block_number,
int block_bytes)
int block_bytes,
int prefill_tp_size,
int prefill_tp_idx)
: splitwise_role(role),
gpu_idx(gpu_idx),
port(port),
@@ -59,6 +61,8 @@ RDMACommunicator::RDMACommunicator(std::string& role,
local_cache_value_ptr_layer_head_(std::move(local_value_cache)),
block_number(block_number),
block_size_byte(block_bytes),
prefill_tp_size(prefill_tp_size),
prefill_tp_idx(prefill_tp_idx),
RDMACommunicator_status(0),
rdma_event_channel_epoll_fd(-1) {
try {
@@ -480,11 +484,14 @@ std::string RDMACommunicator::fetch_local_ip() {
*
* @param dst_ip Destination IP address
* @param dst_port Destination port
* @param dest_tp_size Default 0: assumes dest has same tp_size as source;
* otherwise specifies decode tp_size
* @return ConnStatus::kConnected on success, ConnStatus::kError on failure
*/
int RDMACommunicator::connect(const std::string& dst_ip,
const std::string& dst_port) {
const std::string& dst_port,
int dest_tp_size = 0) {
std::string url = dst_ip + ":" + dst_port;
// Initialize IB devices if not already done
@@ -515,6 +522,10 @@ int RDMACommunicator::connect(const std::string& dst_ip,
ctx->conn.layer_number = layer_number;
ctx->conn.block_number = block_number;
ctx->conn.block_byte_size = block_size_byte;
if (dest_tp_size > 0)
ctx->conn.decode_tp_size = dest_tp_size;
else
ctx->conn.decode_tp_size = prefill_tp_size;
// Get port information for the connection
if (get_port_info(ctx->context, ib_dev->port, &ctx->portinfo)) {
@@ -537,9 +548,6 @@ int RDMACommunicator::connect(const std::string& dst_ip,
ERR("Couldn't get/exchange port info with destination");
return static_cast<int>(ConnStatus::kError);
} else {
std::lock_guard<std::mutex> lock(mutex_);
ctx->conn.connected = 1;
conn_map[url] = ctx;
client_exchange_mr(ctx);
}
@@ -589,6 +597,10 @@ int RDMACommunicator::connect(const std::string& dst_ip,
}
}
std::lock_guard<std::mutex> lock(mutex_);
ctx->conn.connected = 1;
conn_map[url] = ctx;
WARN("connect end ....");
return static_cast<int>(ConnStatus::kConnected);
}
@@ -649,6 +661,7 @@ int RDMACommunicator::client_listener() {
bool RDMACommunicator::is_connected(const std::string& dst_ip,
const std::string& dst_port) {
std::lock_guard<std::mutex> lock(mutex_);
std::string url = dst_ip + ":" + dst_port;
return conn_map.find(url) != conn_map.end();
}
@@ -889,17 +902,25 @@ int RDMACommunicator::write_cache(const std::string& ip,
uint32_t cache_value_rkey =
ctx->conn.write_cache_value_remote_rkey_list[layer_idx];
uint32_t crc_cache_key_rkey, crc_cache_value_rkey;
bool pd_tp_size_is_same = prefill_tp_size == ctx->conn.decode_tp_size;
uint64_t offset_in_block =
pd_tp_size_is_same ? 0 : block_size_byte * prefill_tp_idx;
uint64_t total_block_size_byte =
pd_tp_size_is_same ? block_size_byte : block_size_byte * prefill_tp_size;
for (size_t block_index = 0; block_index < block_num; ++block_index) {
char* char_ptr = static_cast<char*>(
ctx->conn.write_cache_key_remote_ptr_list[layer_idx]);
cache_key_remote_addr[block_index] =
(uint64_t(char_ptr + remote_block_ids[block_index] * block_size_byte));
cache_key_remote_addr[block_index] = (uint64_t(
char_ptr + remote_block_ids[block_index] * total_block_size_byte +
offset_in_block));
char_ptr = static_cast<char*>(
ctx->conn.write_cache_value_remote_ptr_list[layer_idx]);
cache_value_remote_addr[block_index] =
(uint64_t(char_ptr + remote_block_ids[block_index] * block_size_byte));
cache_value_remote_addr[block_index] = (uint64_t(
char_ptr + remote_block_ids[block_index] * total_block_size_byte +
offset_in_block));
}
ctx->conn.wc_target_count = 0;
for (int i = 0; i < 2; ++i) {
bool is_key = (i == 0);
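The address arithmetic above is the heart of the mismatched-tp path: when prefill tp_size differs from decode tp_size, each decode-side block stores the concatenation of all prefill shards, so prefill rank prefill_tp_idx writes its shard at a per-rank offset inside every remote block. A Python transcription of that computation (illustrative only):

def remote_block_addr(base_ptr, remote_block_id, block_size_byte,
                      prefill_tp_size, prefill_tp_idx, decode_tp_size):
    """Remote write address for one cache block, mirroring the C++ above."""
    same_tp = prefill_tp_size == decode_tp_size
    offset_in_block = 0 if same_tp else block_size_byte * prefill_tp_idx
    total_block_size_byte = block_size_byte if same_tp else block_size_byte * prefill_tp_size
    return base_ptr + remote_block_id * total_block_size_byte + offset_in_block

# prefill tp_size=2 writing into decode tp_size=1, with 1 MiB per-rank blocks:
MB = 1 << 20
assert remote_block_addr(0, 3, MB, 2, 0, 1) == 3 * 2 * MB        # rank 0 shard
assert remote_block_addr(0, 3, MB, 2, 1, 1) == 3 * 2 * MB + MB   # rank 1 shard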

View File

@@ -14,10 +14,39 @@ PYBIND11_MODULE(rdma_comm, m) {
std::vector<int64_t>,
std::vector<int64_t>,
int,
int>())
.def("connect", &RDMACommunicator::connect)
.def("is_connected", &RDMACommunicator::is_connected)
.def("write_cache", &RDMACommunicator::write_cache);
int,
int,
int>(),
py::arg("splitwise_role"),
py::arg("gpu_idx"),
py::arg("port"),
py::arg("key_cache_ptrs"),
py::arg("value_cache_ptrs"),
py::arg("block_number"),
py::arg("block_bytes"),
py::arg("prefill_tp_size") = 1,
py::arg("prefill_tp_idx") = 0)
.def("connect",
&RDMACommunicator::connect,
py::arg("dst_ip"),
py::arg("dst_port"),
py::arg("dst_tp_size") =
0, // Default 0: assumes dest has same tp_size as source;
// otherwise specifies decode tp_size
py::call_guard<py::gil_scoped_release>())
.def("is_connected",
&RDMACommunicator::is_connected,
py::arg("dst_ip"),
py::arg("dst_port"),
py::call_guard<py::gil_scoped_release>())
.def("write_cache",
&RDMACommunicator::write_cache,
py::arg("dst_ip"),
py::arg("dst_port"),
py::arg("local_block_ids"),
py::arg("remote_block_ids"),
py::arg("layer_idx"),
py::call_guard<py::gil_scoped_release>());
#ifdef VERSION_INFO
m.attr("__version__") = VERSION_INFO;
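A usage sketch for the rebuilt bindings: the keyword names come from the binding above, while the values, pointer lists, and port types are illustrative guesses. dst_tp_size=0 (the default) keeps the old behaviour of assuming decode uses the same tp size.

import rdma_comm  # the extension module defined above

key_ptrs = []    # per-layer device pointers (int64); empty here for illustration
value_ptrs = []

comm = rdma_comm.RDMACommunicator(
    splitwise_role="prefill",
    gpu_idx=0,
    port="8623",
    key_cache_ptrs=key_ptrs,
    value_cache_ptrs=value_ptrs,
    block_number=1024,
    block_bytes=1 << 20,
    prefill_tp_size=2,   # new: prefill tensor-parallel world size
    prefill_tp_idx=0,    # new: this rank's index within it
)
# Pass the decode tp size explicitly when prefill and decode tp sizes differ.
status = comm.connect(dst_ip="10.0.0.2", dst_port="8624", dst_tp_size=1)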

View File

@@ -34,6 +34,8 @@ class RDMACommManager:
max_block_num,
block_bytes,
rdma_port,
prefill_tp_size,
prefill_tp_idx,
):
try:
import rdma_comm
@@ -51,12 +53,16 @@ class RDMACommManager:
cache_v_ptr_list,
max_block_num,
block_bytes,
prefill_tp_size,
prefill_tp_idx,
)
self.splitwise_role = splitwise_role
self.connected_rdma = set()
logger.info(f"init rdma messager {gpu_id} {rdma_port}")
logger.info(
f"init rdma messager {gpu_id} {rdma_port}, prefill_tp_size: {prefill_tp_size}, prefill_tp_idx: {prefill_tp_idx}"
)
def connect(self, ip, port):
def connect(self, ip, port, tp_size):
"""
Connect to remote gpu and write cache.
"""
@@ -65,7 +71,7 @@ class RDMACommManager:
if ret:
return True
ret = self.messager.connect(ip, str(port))
ret = self.messager.connect(ip, str(port), tp_size)
logger.info(f"connect to remote rdma address {ip}:{port} status is {ret}")
return ret == 0

View File

@@ -1888,6 +1888,7 @@ class FDConfig:
logger.info(f"disaggregate_info: {self.disaggregate_info}")
if self.router_config:
# the information for registering this server to the router
self.register_info = {
"role": self.scheduler_config.splitwise_role,
"host_ip": self.host_ip,
@@ -1897,6 +1898,7 @@ class FDConfig:
"engine_worker_queue_port": engine_worker_queue_port,
"device_ids": self.local_device_ids,
"transfer_protocol": self.cache_config.cache_transfer_protocol.split(","),
"tp_size": self.parallel_config.tensor_parallel_size,
}
logger.info(f"register_info: {self.register_info}")

View File

@@ -150,6 +150,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
"FD_ENABLE_PDL": lambda: int(os.getenv("FD_ENABLE_PDL", "1")),
# "Number of tokens in the group for Mixture of Experts (MoE) computation processing on HPU"
"FD_HPU_CHUNK_SIZE": lambda: int(os.getenv("FD_HPU_CHUNK_SIZE", "64")),
"FD_PREFILL_WAIT_DECODE_RESOURCE_SECONDS": lambda: int(os.getenv("FD_PREFILL_WAIT_DECODE_RESOURCE_SECONDS", "30")),
}

View File

@@ -95,7 +95,7 @@ class Router:
async def register_instance(self, instance_info_dict: dict):
"""Register an instance asynchronously"""
try:
inst_info = InstanceInfo(**instance_info_dict)
inst_info = InstanceInfo.from_dict(instance_info_dict)
except Exception as e:
logger.error(f"register instance failed: {e}")
raise
@@ -173,11 +173,17 @@ class Router:
logger.debug(f"Received request: {request_data}")
prefill_server, decode_server = await self.select_pd()
if prefill_server.tp_size != decode_server.tp_size and decode_server.tp_size != 1:
raise HTTPException(
status_code=400,
detail="The tp_size of prefill and decode should be equal, or the tp_size of decode should be 1",
)
# TODO: unify the disaggregate_info in server and remove redundancy params
is_same_node = prefill_server.host_ip == decode_server.host_ip
use_ipc = (
is_same_node and "ipc" in prefill_server.transfer_protocol and "ipc" in decode_server.transfer_protocol
)
is_support_ipc = "ipc" in prefill_server.transfer_protocol and "ipc" in decode_server.transfer_protocol
is_same_tp_size = prefill_server.tp_size == decode_server.tp_size
use_ipc = is_same_node and is_support_ipc and is_same_tp_size
cache_info = {}
if use_ipc:
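Condensed, the pairing rule the router enforces here (a sketch operating on objects shaped like the InstanceInfo instances returned by select_pd()): prefill and decode must share a tp_size unless decode runs with tp_size 1, and IPC is only chosen when both sides are on the same node, both advertise ipc, and their tp sizes match.

def tp_sizes_compatible(prefill, decode) -> bool:
    """tp sizes must match, or decode must run with tp_size == 1."""
    return prefill.tp_size == decode.tp_size or decode.tp_size == 1

def should_use_ipc(prefill, decode) -> bool:
    """IPC needs same node, ipc support on both sides, and equal tp sizes."""
    is_same_node = prefill.host_ip == decode.host_ip
    is_support_ipc = "ipc" in prefill.transfer_protocol and "ipc" in decode.transfer_protocol
    is_same_tp_size = prefill.tp_size == decode.tp_size
    return is_same_node and is_support_ipc and is_same_tp_size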

View File

@@ -15,9 +15,9 @@
"""
import asyncio
from dataclasses import asdict, dataclass, field
from dataclasses import MISSING, asdict, dataclass, field, fields
from enum import Enum
from typing import List, Union
from typing import Any, List, Union
import aiohttp
import requests
@@ -39,6 +39,24 @@ class InstanceInfo:
transfer_protocol: List[str] = field(default_factory=list)
rdma_ports: Union[List[str], List[int]] = field(default_factory=list)
device_ids: Union[List[str], List[int]] = field(default_factory=list)
tp_size: int = 1
@classmethod
def from_dict(cls, info_dict: dict[str, Any]) -> "InstanceInfo":
"""Create instance from dict arguments"""
kwargs = {}
for field_def in fields(cls):
name = field_def.name
if name in info_dict:
value = info_dict[name]
else:
# handle default and default_factory
if field_def.default is not MISSING:
value = field_def.default
else:
value = field_def.default_factory()
kwargs[name] = value
return cls(**kwargs)
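A self-contained toy reproducing the from_dict pattern above (the Toy dataclass is hypothetical, not FastDeploy code): keys that are not dataclass fields are dropped, and omitted fields fall back to their declared defaults, which is how registrations from servers that do not report tp_size end up with tp_size 1.

from dataclasses import MISSING, dataclass, field, fields
from typing import List

@dataclass
class Toy:
    transfer_protocol: List[str] = field(default_factory=list)
    tp_size: int = 1

    @classmethod
    def from_dict(cls, info_dict: dict) -> "Toy":
        kwargs = {}
        for f in fields(cls):
            if f.name in info_dict:
                kwargs[f.name] = info_dict[f.name]
            elif f.default is not MISSING:
                kwargs[f.name] = f.default
            else:
                kwargs[f.name] = f.default_factory()
        return cls(**kwargs)

t = Toy.from_dict({"transfer_protocol": ["rdma"], "extra_key": "ignored"})
assert t.transfer_protocol == ["rdma"] and t.tp_size == 1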
def __post_init__(self):
"""check and unify fields"""

View File

@@ -199,7 +199,7 @@ class SplitwiseConnector:
f"{task.disaggregate_info['cache_info']['rdma']['ip']}:"
+ f"{task.disaggregate_info['cache_info']['rdma']['port']}"
)
self.logger.info(f"send splitwise tasks to port {addr} decode")
self.logger.info(f"send splitwise tasks to port {addr} decode, {task.request_id}")
self.current_request_ids[task.request_id] = "init"
decode_diagg = task.disaggregate_info["cache_info"]
task.disaggregate_info["cache_info"] = self.cfg.disaggregate_info["cache_info"]
@@ -271,6 +271,7 @@ class SplitwiseConnector:
)
def check_decode_allocated(self, task):
self.logger.debug(f"start check decode allocated: {task.request_id}")
start_time = time.time()
if task.disaggregate_info is None:
return True, ""
@@ -280,7 +281,7 @@ class SplitwiseConnector:
return True, ""
while self.current_request_ids[task.request_id] == "init":
time.sleep(0.001)
if time.time() - start_time > 30:
if time.time() - start_time > envs.FD_PREFILL_WAIT_DECODE_RESOURCE_SECONDS:
del self.current_request_ids[task.request_id]
return False, "timeout"
msg = self.current_request_ids[task.request_id]
@@ -363,6 +364,7 @@ class SplitwiseConnector:
"rdma_ports": self.cfg.disaggregate_info["cache_info"]["rdma"]["rdma_port"],
"transfer_protocol": "rdma",
"dest_block_ids": dsg_info["block_tables"],
"decode_tp_size": self.cfg.parallel_config.tensor_parallel_size,
}
addr = f"{dsg_info['cache_info']['rdma']['ip']}:" + f"{dsg_info['cache_info']['rdma']['port']}"
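The wait performed in check_decode_allocated above is bounded by the FD_PREFILL_WAIT_DECODE_RESOURCE_SECONDS environment variable (default 30) rather than a hard-coded 30 seconds. A simplified sketch of that loop, with is_still_pending standing in for the real per-request status check:

import os
import time

def wait_for_decode_resource(is_still_pending, request_id: str):
    """Poll until decode reports readiness or the configured timeout expires."""
    timeout = int(os.getenv("FD_PREFILL_WAIT_DECODE_RESOURCE_SECONDS", "30"))
    start = time.time()
    while is_still_pending(request_id):
        time.sleep(0.001)
        if time.time() - start > timeout:
            return False, "timeout"
    return True, ""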

View File

@@ -0,0 +1,427 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Test splitwise deployment: use local_scheduler + router,
# set ENABLE_V1_KVCACHE_SCHEDULER to 1, use rdma to transfer cache,
# the tp_size of prefill is 2 and the tp_size of decode is 1.
import json
import os
import shutil
import signal
import subprocess
import sys
import time
import pytest
import requests
from utils.serving_utils import (
FD_API_PORT,
FD_CACHE_QUEUE_PORT,
FD_ENGINE_QUEUE_PORT,
FD_METRICS_PORT,
clean,
get_registered_number,
)
# Read ports from environment variables; use default values if not set
FD_CONNECTOR_PORT = int(os.getenv("FD_CONNECTOR_PORT", 8433))
FD_ROUTER_PORT = int(os.getenv("FD_ROUTER_PORT", 8533))
FD_RDMA_PORT = int(os.getenv("FD_RDMA_PORT", 8623))
# List of ports to clean before and after tests
PORTS_TO_CLEAN = [
FD_API_PORT,
FD_ENGINE_QUEUE_PORT,
FD_METRICS_PORT,
FD_CACHE_QUEUE_PORT,
FD_CONNECTOR_PORT,
FD_RDMA_PORT,
FD_RDMA_PORT + 1,
FD_API_PORT + 1,
FD_ENGINE_QUEUE_PORT + 1,
FD_METRICS_PORT + 1,
FD_CACHE_QUEUE_PORT + 1,
FD_CONNECTOR_PORT + 1,
FD_RDMA_PORT + 2,
FD_ROUTER_PORT,
]
@pytest.fixture(scope="session", autouse=True)
def setup_and_run_server():
"""
Pytest fixture that runs once per test session:
- Cleans ports before tests
- Starts the API server as a subprocess
- Waits for server port to open (up to 30 seconds)
- Tears down server after all tests finish
"""
print("Pre-test port cleanup...")
clean(PORTS_TO_CLEAN)
print("Cleaning log dirs...")
if os.path.exists("log_router") and os.path.isdir("log_router"):
shutil.rmtree("log_router")
if os.path.exists("log_prefill") and os.path.isdir("log_prefill"):
shutil.rmtree("log_prefill")
if os.path.exists("log_decode") and os.path.isdir("log_decode"):
shutil.rmtree("log_decode")
base_path = os.getenv("MODEL_PATH")
if base_path:
model_path = os.path.join(base_path, "ERNIE-4.5-0.3B-Paddle")
else:
model_path = "baidu/ERNIE-4.5-0.3B-Paddle"
print(f"model_path: {model_path}")
# get rdma nics
current_dir = os.path.dirname(os.path.abspath(__file__))
shell_path = os.path.join(current_dir, "utils/get_rdma_nics.sh")
output = subprocess.check_output(["bash", shell_path, "gpu"], text=True)
_, rdma_nics = output.split("=")
print(f"shell_path: {shell_path}, rdma_nics: {rdma_nics}")
# router
print("start router...")
env_router = os.environ.copy()
env_router["FD_LOG_DIR"] = "log_router"
router_log_path = "router.log"
router_cmd = [
sys.executable,
"-m",
"fastdeploy.router.launch",
"--port",
str(FD_ROUTER_PORT),
"--splitwise",
]
with open(router_log_path, "w") as logfile:
process_router = subprocess.Popen(
router_cmd,
stdout=logfile,
stderr=subprocess.STDOUT,
start_new_session=True, # Enables killing full group via os.killpg
env=env_router,
)
# prefill instance
print("start prefill...")
env_prefill = os.environ.copy()
env_prefill["CUDA_VISIBLE_DEVICES"] = "0,1"
env_prefill["FD_LOG_DIR"] = "log_prefill"
env_prefill["KVCACHE_RDMA_NICS"] = rdma_nics
prefill_log_path = "prefill.log"
prefill_cmd = [
sys.executable,
"-m",
"fastdeploy.entrypoints.openai.api_server",
"--model",
model_path,
"--port",
str(FD_API_PORT),
"--engine-worker-queue-port",
str(FD_ENGINE_QUEUE_PORT),
"--metrics-port",
str(FD_METRICS_PORT),
"--cache-queue-port",
str(FD_CACHE_QUEUE_PORT),
"--max-model-len",
"8192",
"--tensor-parallel-size",
"2",
"--splitwise-role",
"prefill",
"--cache-transfer-protocol",
"rdma",
"--rdma-comm-ports",
f"{FD_RDMA_PORT},{FD_RDMA_PORT+1}",
"--pd-comm-port",
str(FD_CONNECTOR_PORT),
"--router",
f"0.0.0.0:{FD_ROUTER_PORT}",
]
# Start subprocess in new process group
with open(prefill_log_path, "w") as logfile:
process_prefill = subprocess.Popen(
prefill_cmd,
stdout=logfile,
stderr=subprocess.STDOUT,
start_new_session=True, # Enables killing full group via os.killpg
env=env_prefill,
)
time.sleep(1)
# decode instance
print("start decode...")
env_decode = os.environ.copy()
env_decode["CUDA_VISIBLE_DEVICES"] = "1"
env_decode["FD_LOG_DIR"] = "log_decode"
env_decode["KVCACHE_RDMA_NICS"] = rdma_nics
decode_log_path = "decode.log"
decode_cmd = [
sys.executable,
"-m",
"fastdeploy.entrypoints.openai.api_server",
"--model",
model_path,
"--port",
str(FD_API_PORT + 1),
"--engine-worker-queue-port",
str(FD_ENGINE_QUEUE_PORT + 1),
"--metrics-port",
str(FD_METRICS_PORT + 1),
"--cache-queue-port",
str(FD_CACHE_QUEUE_PORT + 1),
"--max-model-len",
"8192",
"--splitwise-role",
"decode",
"--cache-transfer-protocol",
"rdma",
"--rdma-comm-ports",
str(FD_RDMA_PORT + 2),
"--pd-comm-port",
str(FD_CONNECTOR_PORT + 1),
"--router",
f"0.0.0.0:{FD_ROUTER_PORT}",
]
# Start subprocess in new process group
with open(decode_log_path, "w") as logfile:
process_decode = subprocess.Popen(
decode_cmd,
stdout=logfile,
stderr=subprocess.STDOUT,
start_new_session=True, # Enables killing full group via os.killpg
env=env_decode,
)
# Wait up to 300 seconds for API server to be ready
for _ in range(60):
registered_numbers = get_registered_number(f"0.0.0.0:{FD_ROUTER_PORT}")
if registered_numbers["prefill"] >= 1 and registered_numbers["decode"] >= 1:
print("Prefill and decode servers are both online")
break
time.sleep(5)
else:
print("[TIMEOUT] API server failed to start in 5 minutes. Cleaning up...")
try:
os.killpg(process_router.pid, signal.SIGTERM)
os.killpg(process_prefill.pid, signal.SIGTERM)
os.killpg(process_decode.pid, signal.SIGTERM)
clean(PORTS_TO_CLEAN)
except Exception as e:
print(f"Failed to kill process group: {e}")
raise RuntimeError(f"API server did not start on port {FD_API_PORT}")
yield # Run tests
print("\n===== Post-test server cleanup... =====")
try:
os.killpg(process_router.pid, signal.SIGTERM)
os.killpg(process_prefill.pid, signal.SIGTERM)
os.killpg(process_decode.pid, signal.SIGTERM)
clean(PORTS_TO_CLEAN)
print(f"Prefill server (pid={process_prefill.pid}) terminated")
print(f"Decode server (pid={process_decode.pid}) terminated")
except Exception as e:
print(f"Failed to terminate API server: {e}")
@pytest.fixture(scope="session")
def api_url(request):
"""
Returns the API endpoint URL for chat completions.
"""
return f"http://0.0.0.0:{FD_ROUTER_PORT}/v1/chat/completions"
@pytest.fixture(scope="session")
def metrics_url(request):
"""
Returns the metrics endpoint URL.
"""
return f"http://0.0.0.0:{FD_METRICS_PORT}/metrics"
@pytest.fixture
def headers():
"""
Returns common HTTP request headers.
"""
return {"Content-Type": "application/json"}
def test_metrics_config(metrics_url):
timeout = 600
url = metrics_url.replace("metrics", "config-info")
res = requests.get(url, timeout=timeout)
assert res.status_code == 200
def send_request(url, payload, timeout=600):
"""
Send a request to the given URL and return the response.
"""
headers = {
"Content-Type": "application/json",
}
try:
res = requests.post(url, headers=headers, json=payload, timeout=timeout)
print("🟢 Receiving response...\n")
return res
except requests.exceptions.Timeout:
print(f"❌ Request timed out (over {timeout} seconds)")
return None
except requests.exceptions.RequestException as e:
print(f"❌ Request failed: {e}")
return None
def get_stream_chunks(response):
"""Parse a streaming response into a list of chunk dicts (List[dict])"""
chunks = []
if response.status_code == 200:
for line in response.iter_lines(decode_unicode=True):
if line:
if line.startswith("data: "):
line = line[len("data: ") :]
if line.strip() == "[DONE]":
break
try:
chunk = json.loads(line)
chunks.append(chunk)
except Exception as e:
print(f"Parse failed: {e}, line content: {line}")
else:
print(f"Request failed, status code: {response.status_code}")
print("Response content:", response.text)
return chunks
def test_chat_usage_stream(api_url):
"""Test streaming chat usage"""
payload = {
"model": "default",
"temperature": 0,
"top_p": 0,
"seed": 33,
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "牛顿的三大运动定律是什么?"},
],
"max_tokens": 50,
"stream": True,
"stream_options": {"include_usage": True, "continuous_usage_stats": True},
"metadata": {"min_tokens": 10},
}
response = send_request(url=api_url, payload=payload)
chunks = get_stream_chunks(response)
result = "".join([x["choices"][0]["delta"]["content"] for x in chunks[:-1]])
print("Decode Response:", result)
assert result != "", "result is empty"
usage = chunks[-1]["usage"]
total_tokens = usage["completion_tokens"] + usage["prompt_tokens"]
assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens exceeds max_tokens"
assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens is less than min_tokens"
assert usage["total_tokens"] == total_tokens, "total_tokens does not equal prompt_tokens + completion_tokens"
def test_chat_usage_non_stream(api_url):
"""Test non-streaming chat usage"""
payload = {
"model": "default",
"temperature": 0,
"top_p": 0,
"seed": 33,
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "牛顿的三大运动定律是什么?"},
],
"max_tokens": 50,
"stream": False,
"metadata": {"min_tokens": 10},
}
response = send_request(url=api_url, payload=payload).json()
usage = response["usage"]
result = response["choices"][0]["message"]["content"]
assert result != "", "result is empty"
total_tokens = usage["completion_tokens"] + usage["prompt_tokens"]
assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens exceeds max_tokens"
assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens is less than min_tokens"
assert usage["total_tokens"] == total_tokens, "total_tokens does not equal prompt_tokens + completion_tokens"
def test_non_chat_usage_stream(api_url):
"""Test streaming completions (non-chat) usage"""
payload = {
"model": "default",
"temperature": 0,
"top_p": 0,
"seed": 33,
"prompt": "牛顿的三大运动定律是什么?",
"max_tokens": 50,
"stream": True,
"stream_options": {"include_usage": True, "continuous_usage_stats": True},
"metadata": {"min_tokens": 10},
}
api_url = api_url.replace("chat/completions", "completions")
response = send_request(url=api_url, payload=payload)
chunks = get_stream_chunks(response)
result = "".join([x["choices"][0]["text"] for x in chunks[:-1]])
print("Decode Response:", result)
assert result != "", "result is empty"
usage = chunks[-1]["usage"]
total_tokens = usage["completion_tokens"] + usage["prompt_tokens"]
assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens exceeds max_tokens"
assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens is less than min_tokens"
assert usage["total_tokens"] == total_tokens, "total_tokens does not equal prompt_tokens + completion_tokens"
def test_non_chat_usage_non_stream(api_url):
"""Test non-streaming completions (non-chat) usage"""
payload = {
"model": "default",
"temperature": 0,
"top_p": 0,
"seed": 33,
"prompt": "牛顿的三大运动定律是什么?",
"max_tokens": 50,
"stream": False,
"metadata": {"min_tokens": 10},
}
api_url = api_url.replace("chat/completions", "completions")
response = send_request(url=api_url, payload=payload).json()
usage = response["usage"]
result = response["choices"][0]["text"]
print("Decode Response:", result)
assert result != "", "result is empty"
total_tokens = usage["completion_tokens"] + usage["prompt_tokens"]
assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens exceeds max_tokens"
assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens is less than min_tokens"
assert usage["total_tokens"] == total_tokens, "total_tokens does not equal prompt_tokens + completion_tokens"