From 0925d44f182315ab5195f6f6cea6f2ecc506cf73 Mon Sep 17 00:00:00 2001 From: Juncai <52520497+juncaipeng@users.noreply.github.com> Date: Mon, 1 Dec 2025 17:50:20 +0800 Subject: [PATCH] [PD Disaggregation] support different tp_size for prefill and decode (#5296) * up * up * up * fix --- fastdeploy/cache_manager/cache_messager.py | 47 +- .../include/kvcache_connection.h | 1 + .../kvcache_transfer/include/kvcache_rdma.h | 10 +- .../kvcache_transfer/src/kvcache_rdma.cpp | 41 +- .../kvcache_transfer/src/pybind.cpp | 37 +- .../transfer_factory/rdma_cache_transfer.py | 12 +- fastdeploy/config.py | 2 + fastdeploy/envs.py | 1 + fastdeploy/router/router.py | 14 +- fastdeploy/router/utils.py | 22 +- fastdeploy/splitwise/splitwise_connector.py | 6 +- ...> test_ernie_03b_pd_router_v1_rdma_tp1.py} | 0 .../test_ernie_03b_pd_router_v1_rdma_tp2.py | 427 ++++++++++++++++++ 13 files changed, 584 insertions(+), 36 deletions(-) rename tests/e2e/{test_ernie_03b_pd_router_v1_rdma.py => test_ernie_03b_pd_router_v1_rdma_tp1.py} (100%) create mode 100644 tests/e2e/test_ernie_03b_pd_router_v1_rdma_tp2.py diff --git a/fastdeploy/cache_manager/cache_messager.py b/fastdeploy/cache_manager/cache_messager.py index 8f8318b22..23b6b72f3 100644 --- a/fastdeploy/cache_manager/cache_messager.py +++ b/fastdeploy/cache_manager/cache_messager.py @@ -55,13 +55,13 @@ def parse_args(): default="mixed", help="splitwise role, can be decode, prefill or mixed", ) - parser.add_argument("--rank", type=int, default=0, help="current rank") + parser.add_argument("--rank", type=int, default=0, help="local tp rank id") parser.add_argument("--device_id", type=int, default=0, help="device id") parser.add_argument("--num_layers", type=int, default=1, help="model num layers") parser.add_argument("--key_cache_shape", type=str, default="", help="key cache shape") parser.add_argument("--value_cache_shape", type=str, default="", help="value cache shape") parser.add_argument("--rdma_port", type=str, default="", help="rmda port") - parser.add_argument("--mp_num", type=int, default=1, help="number of model parallel") + parser.add_argument("--mp_num", type=int, default=1, help="number of model parallel, i.e. 
tp_size, tp_num") parser.add_argument("--engine_pid", type=str, default=None, help="engine pid") parser.add_argument( "--protocol", @@ -208,6 +208,8 @@ class CacheMessager: max_block_num, block_bytes, rdma_port, + nranks, + rank, ) self.gpu_id = gpu_id @@ -507,6 +509,8 @@ class CacheMessagerV1: max_block_num, block_bytes, rdma_port, + nranks, + rank, ) self.gpu_id = gpu_id @@ -595,6 +599,7 @@ class CacheMessagerV1: block_id_end = prefilled_token_num // self.block_size # [block_id_start, block_id_end) block_start_end_list.append((block_id_start, block_id_end)) current_prefilled_token_num_list.append(prefilled_token_num) + while True: # from layer0 to last layer sended_layer_idx = self.idx_cache_task_dict[batch_engine_signals[0][0]]["sended_layer_id"] start_layer_idx = sended_layer_idx + 1 @@ -633,13 +638,27 @@ class CacheMessagerV1: current_transfer_protocol = task["transfer_protocol"] if task["transfer_protocol"] == "rdma": target_ip = task["ip"] - target_id = int(task["rdma_ports"][self.rank]) + # Default decode_tp_size to prefill tp_size (self.nranks) if not specified + decode_tp_size = task.get("decode_tp_size", self.nranks) + if len(task["rdma_ports"]) == self.nranks: + target_id = int(task["rdma_ports"][self.rank]) + elif len(task["rdma_ports"]) == 1: + target_id = task["rdma_ports"][0] + else: + task["status"] = "the tp_size of prefill and decode is mismatch" + continue + if "error" in task["status"]: continue # TODO: use is connected to check if the connection is still alive - logger.debug(f"rdma, start connect decode, {target_ip}:{target_id}") - status = self.messager[current_transfer_protocol].connect(target_ip, target_id) + logger.debug( + f"rdma, start connect decode, {target_ip}:{target_id}, " + f"prefill_tp_size:{self.nranks}, decode_tp_size:{decode_tp_size}" + ) + status = self.messager[current_transfer_protocol].connect( + target_ip, target_id, decode_tp_size + ) if status: logger.info(f"connect to {target_ip}:{target_id} success") else: @@ -762,12 +781,22 @@ class CacheMessagerV1: self.engine_worker_queue.connect_task_barrier.wait() logger.info(f"_handle_connect_task recv task: {task}") task_id = task["task_id"] - ip, rdma_port = task["ip"], task["rdma_ports"][self.rank] - status = self.messager["rdma"].connect(ip, rdma_port) - if not status: + ip = task["ip"] + # Default decode_tp_size to self.nranks (number of ranks) if not specified in the task. 
+ decode_tp_size = task.get("decode_tp_size", self.nranks) + rdma_ports = task["rdma_ports"] + rdma_ports_len = len(rdma_ports) + if not (rdma_ports_len == 1 or rdma_ports_len == self.nranks): + # TODO: support other cases + logger.error(f"rdma_ports length should be 1 or equal to mp_num, but got {rdma_ports_len}") response = {"task_id": task_id, "success": False} else: - response = {"task_id": task_id, "success": True} + port = rdma_ports[0] if rdma_ports_len == 1 else rdma_ports[self.rank] + status = self.messager["rdma"].connect(ip, port, decode_tp_size) + if not status: + response = {"task_id": task_id, "success": False} + else: + response = {"task_id": task_id, "success": True} self.engine_worker_queue.connect_task_response_barrier.wait() self.engine_worker_queue.put_connect_rdma_task_response(response) except Exception as e: diff --git a/fastdeploy/cache_manager/transfer_factory/kvcache_transfer/include/kvcache_connection.h b/fastdeploy/cache_manager/transfer_factory/kvcache_transfer/include/kvcache_connection.h index 1e94e0824..d9b442a0a 100644 --- a/fastdeploy/cache_manager/transfer_factory/kvcache_transfer/include/kvcache_connection.h +++ b/fastdeploy/cache_manager/transfer_factory/kvcache_transfer/include/kvcache_connection.h @@ -142,6 +142,7 @@ struct Connection { int wc_target_count; // Configuration + int decode_tp_size; int layer_number; int block_number; int block_byte_size; diff --git a/fastdeploy/cache_manager/transfer_factory/kvcache_transfer/include/kvcache_rdma.h b/fastdeploy/cache_manager/transfer_factory/kvcache_transfer/include/kvcache_rdma.h index e0251f8d4..3a5b2dc78 100644 --- a/fastdeploy/cache_manager/transfer_factory/kvcache_transfer/include/kvcache_rdma.h +++ b/fastdeploy/cache_manager/transfer_factory/kvcache_transfer/include/kvcache_rdma.h @@ -24,11 +24,15 @@ class RDMACommunicator { std::vector local_key_cache, std::vector local_value_cache, int block_number, - int block_bytes); + int block_bytes, + int prefill_tp_size, + int prefill_tp_idx); ~RDMACommunicator(); // Connection management - int connect(const std::string& dst_ip, const std::string& dst_port); + int connect(const std::string& dst_ip, + const std::string& dst_port, + int dest_tp_size); bool is_connected(const std::string& dst_ip, const std::string& dst_port); // Core functionality @@ -120,6 +124,8 @@ class RDMACommunicator { int block_number; // Number of blocks int block_size_byte; // Size of each block in bytes int layer_number; // Number of layers + int prefill_tp_size; // tensor parallelism size for prefill + int prefill_tp_idx; // tensor parallelism index for prefill std::vector> local_cache_key_ptr_per_layer; // Per-layer key pointers diff --git a/fastdeploy/cache_manager/transfer_factory/kvcache_transfer/src/kvcache_rdma.cpp b/fastdeploy/cache_manager/transfer_factory/kvcache_transfer/src/kvcache_rdma.cpp index 4e443872a..60f06bf06 100644 --- a/fastdeploy/cache_manager/transfer_factory/kvcache_transfer/src/kvcache_rdma.cpp +++ b/fastdeploy/cache_manager/transfer_factory/kvcache_transfer/src/kvcache_rdma.cpp @@ -41,7 +41,7 @@ * @param local_key_cache Vector of local key cache pointers * @param local_value_cache Vector of local value cache pointers * @param block_number Number of blocks in cache - * @param block_bytes Size of each block in bytes + * @param block_bytes Bytes of each block in each tp rank * * @throws std::runtime_error If initialization fails */ @@ -51,7 +51,9 @@ RDMACommunicator::RDMACommunicator(std::string& role, std::vector local_key_cache, std::vector local_value_cache, 
int block_number, - int block_bytes) + int block_bytes, + int prefill_tp_size, + int prefill_tp_idx) : splitwise_role(role), gpu_idx(gpu_idx), port(port), @@ -59,6 +61,8 @@ RDMACommunicator::RDMACommunicator(std::string& role, local_cache_value_ptr_layer_head_(std::move(local_value_cache)), block_number(block_number), block_size_byte(block_bytes), + prefill_tp_size(prefill_tp_size), + prefill_tp_idx(prefill_tp_idx), RDMACommunicator_status(0), rdma_event_channel_epoll_fd(-1) { try { @@ -480,11 +484,14 @@ std::string RDMACommunicator::fetch_local_ip() { * * @param dst_ip Destination IP address * @param dst_port Destination port + * @param dest_tp_size Default 0: assumes dest has same tp_size as source; + * otherwise specifies decode tp_size * @return ConnStatus::kConnected ConnStatus::kError; */ int RDMACommunicator::connect(const std::string& dst_ip, - const std::string& dst_port) { + const std::string& dst_port, + int dest_tp_size = 0) { std::string url = dst_ip + ":" + dst_port; // Initialize IB devices if not already done @@ -515,6 +522,10 @@ int RDMACommunicator::connect(const std::string& dst_ip, ctx->conn.layer_number = layer_number; ctx->conn.block_number = block_number; ctx->conn.block_byte_size = block_size_byte; + if (dest_tp_size > 0) + ctx->conn.decode_tp_size = dest_tp_size; + else + ctx->conn.decode_tp_size = prefill_tp_size; // Get port information for the connection if (get_port_info(ctx->context, ib_dev->port, &ctx->portinfo)) { @@ -537,9 +548,6 @@ int RDMACommunicator::connect(const std::string& dst_ip, ERR("Couldn't getexchange port infodestinations"); return static_cast(ConnStatus::kError); } else { - std::lock_guard lock(mutex_); - ctx->conn.connected = 1; - conn_map[url] = ctx; client_exchange_mr(ctx); } @@ -589,6 +597,10 @@ int RDMACommunicator::connect(const std::string& dst_ip, } } + std::lock_guard lock(mutex_); + ctx->conn.connected = 1; + conn_map[url] = ctx; + WARN("connect end ...."); return static_cast(ConnStatus::kConnected); } @@ -649,6 +661,7 @@ int RDMACommunicator::client_listener() { bool RDMACommunicator::is_connected(const std::string& dst_ip, const std::string& dst_port) { + std::lock_guard lock(mutex_); std::string url = dst_ip + ":" + dst_port; return conn_map.find(url) != conn_map.end(); } @@ -889,17 +902,25 @@ int RDMACommunicator::write_cache(const std::string& ip, uint32_t cache_value_rkey = ctx->conn.write_cache_value_remote_rkey_list[layer_idx]; uint32_t crc_cache_key_rkey, crc_cache_value_rkey; + bool pd_tp_size_is_same = prefill_tp_size == ctx->conn.decode_tp_size; + uint64_t offset_in_block = + pd_tp_size_is_same ? 0 : block_size_byte * prefill_tp_idx; + uint64_t total_block_size_byte = + pd_tp_size_is_same ? 
block_size_byte : block_size_byte * prefill_tp_size; for (size_t block_index = 0; block_index < block_num; ++block_index) { char* char_ptr = static_cast( ctx->conn.write_cache_key_remote_ptr_list[layer_idx]); - cache_key_remote_addr[block_index] = - (uint64_t(char_ptr + remote_block_ids[block_index] * block_size_byte)); + cache_key_remote_addr[block_index] = (uint64_t( + char_ptr + remote_block_ids[block_index] * total_block_size_byte + + offset_in_block)); char_ptr = static_cast( ctx->conn.write_cache_value_remote_ptr_list[layer_idx]); - cache_value_remote_addr[block_index] = - (uint64_t(char_ptr + remote_block_ids[block_index] * block_size_byte)); + cache_value_remote_addr[block_index] = (uint64_t( + char_ptr + remote_block_ids[block_index] * total_block_size_byte + + offset_in_block)); } + ctx->conn.wc_target_count = 0; for (int i = 0; i < 2; ++i) { bool is_key = (i == 0); diff --git a/fastdeploy/cache_manager/transfer_factory/kvcache_transfer/src/pybind.cpp b/fastdeploy/cache_manager/transfer_factory/kvcache_transfer/src/pybind.cpp index 9ffcb35b2..9b42f34f7 100644 --- a/fastdeploy/cache_manager/transfer_factory/kvcache_transfer/src/pybind.cpp +++ b/fastdeploy/cache_manager/transfer_factory/kvcache_transfer/src/pybind.cpp @@ -14,10 +14,39 @@ PYBIND11_MODULE(rdma_comm, m) { std::vector, std::vector, int, - int>()) - .def("connect", &RDMACommunicator::connect) - .def("is_connected", &RDMACommunicator::is_connected) - .def("write_cache", &RDMACommunicator::write_cache); + int, + int, + int>(), + py::arg("splitwise_role"), + py::arg("gpu_idx"), + py::arg("port"), + py::arg("key_cache_ptrs"), + py::arg("value_cache_ptrs"), + py::arg("block_number"), + py::arg("block_bytes"), + py::arg("prefill_tp_size") = 1, + py::arg("prefill_tp_idx") = 0) + .def("connect", + &RDMACommunicator::connect, + py::arg("dst_ip"), + py::arg("dst_port"), + py::arg("dst_tp_size") = + 0, // Default 0: assumes dest has same tp_size as source; + // otherwise specifies decode tp_size + py::call_guard()) + .def("is_connected", + &RDMACommunicator::is_connected, + py::arg("dst_ip"), + py::arg("dst_port"), + py::call_guard()) + .def("write_cache", + &RDMACommunicator::write_cache, + py::arg("dst_ip"), + py::arg("dst_port"), + py::arg("local_block_ids"), + py::arg("remote_block_ids"), + py::arg("layer_idx"), + py::call_guard()); #ifdef VERSION_INFO m.attr("__version__") = VERSION_INFO; diff --git a/fastdeploy/cache_manager/transfer_factory/rdma_cache_transfer.py b/fastdeploy/cache_manager/transfer_factory/rdma_cache_transfer.py index 6a0c0ac36..0548e8f84 100644 --- a/fastdeploy/cache_manager/transfer_factory/rdma_cache_transfer.py +++ b/fastdeploy/cache_manager/transfer_factory/rdma_cache_transfer.py @@ -34,6 +34,8 @@ class RDMACommManager: max_block_num, block_bytes, rdma_port, + prefill_tp_size, + prefill_tp_idx, ): try: import rdma_comm @@ -51,12 +53,16 @@ class RDMACommManager: cache_v_ptr_list, max_block_num, block_bytes, + prefill_tp_size, + prefill_tp_idx, ) self.splitwise_role = splitwise_role self.connected_rdma = set() - logger.info(f"init rdma messager {gpu_id} {rdma_port}") + logger.info( + f"init rdma messager {gpu_id} {rdma_port}, prefill_tp_size: {prefill_tp_size}, prefill_tp_idx: {prefill_tp_idx}" + ) - def connect(self, ip, port): + def connect(self, ip, port, tp_size): """ Connect to remote gpu and write cache. 
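+        tp_size is the decode instance's tensor-parallel size; it is passed through to the
+        underlying RDMA connect so the transfer side knows the remote tp_size when it later
+        writes cache blocks.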
""" @@ -65,7 +71,7 @@ class RDMACommManager: if ret: return True - ret = self.messager.connect(ip, str(port)) + ret = self.messager.connect(ip, str(port), tp_size) logger.info(f"connect to remote rdma address {ip}:{port} status is {ret}") return ret == 0 diff --git a/fastdeploy/config.py b/fastdeploy/config.py index aea9bf33d..2870b9816 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -1888,6 +1888,7 @@ class FDConfig: logger.info(f"disaggregate_info: {self.disaggregate_info}") if self.router_config: + # the information for registering this server to router self.register_info = { "role": self.scheduler_config.splitwise_role, "host_ip": self.host_ip, @@ -1897,6 +1898,7 @@ class FDConfig: "engine_worker_queue_port": engine_worker_queue_port, "device_ids": self.local_device_ids, "transfer_protocol": self.cache_config.cache_transfer_protocol.split(","), + "tp_size": self.parallel_config.tensor_parallel_size, } logger.info(f"register_info: {self.register_info}") diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index c46744e00..414b5abe8 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -150,6 +150,7 @@ environment_variables: dict[str, Callable[[], Any]] = { "FD_ENABLE_PDL": lambda: int(os.getenv("FD_ENABLE_PDL", "1")), # "Number of tokens in the group for Mixture of Experts (MoE) computation processing on HPU" "FD_HPU_CHUNK_SIZE": lambda: int(os.getenv("FD_HPU_CHUNK_SIZE", "64")), + "FD_PREFILL_WAIT_DECODE_RESOURCE_SECONDS": lambda: int(os.getenv("FD_PREFILL_WAIT_DECODE_RESOURCE_SECONDS", "30")), } diff --git a/fastdeploy/router/router.py b/fastdeploy/router/router.py index 23d40e00e..ffeb6b3f3 100644 --- a/fastdeploy/router/router.py +++ b/fastdeploy/router/router.py @@ -95,7 +95,7 @@ class Router: async def register_instance(self, instance_info_dict: dict): """Register an instance asynchronously""" try: - inst_info = InstanceInfo(**instance_info_dict) + inst_info = InstanceInfo.from_dict(instance_info_dict) except Exception as e: logger.error(f"register instance failed: {e}") raise @@ -173,11 +173,17 @@ class Router: logger.debug(f"Received request: {request_data}") prefill_server, decode_server = await self.select_pd() + if prefill_server.tp_size != decode_server.tp_size and decode_server.tp_size != 1: + raise HTTPException( + status_code=400, + detail="The tp_size of prefill and decode should be equal or the tp_size of decode is 1", + ) + # TODO: unify the disaggregate_info in server and remove redundancy params is_same_node = prefill_server.host_ip == decode_server.host_ip - use_ipc = ( - is_same_node and "ipc" in prefill_server.transfer_protocol and "ipc" in decode_server.transfer_protocol - ) + is_support_ipc = "ipc" in prefill_server.transfer_protocol and "ipc" in decode_server.transfer_protocol + is_same_tp_size = prefill_server.tp_size == decode_server.tp_size + use_ipc = is_same_node and is_support_ipc and is_same_tp_size cache_info = {} if use_ipc: diff --git a/fastdeploy/router/utils.py b/fastdeploy/router/utils.py index 7c83db90f..596c0ceec 100644 --- a/fastdeploy/router/utils.py +++ b/fastdeploy/router/utils.py @@ -15,9 +15,9 @@ """ import asyncio -from dataclasses import asdict, dataclass, field +from dataclasses import MISSING, asdict, dataclass, field, fields from enum import Enum -from typing import List, Union +from typing import Any, List, Union import aiohttp import requests @@ -39,6 +39,24 @@ class InstanceInfo: transfer_protocol: List[str] = field(default_factory=list) rdma_ports: Union[List[str], List[int]] = field(default_factory=list) 
device_ids: Union[List[str], List[int]] = field(default_factory=list) + tp_size: int = 1 + + @classmethod + def from_dict(cls, info_dict: dict[str, Any]) -> "InstanceInfo": + """Create instance from dict arguments""" + kwargs = {} + for field_def in fields(cls): + name = field_def.name + if name in info_dict: + value = info_dict[name] + else: + # handle default and default_factory + if field_def.default is not MISSING: + value = field_def.default + else: + value = field_def.default_factory() + kwargs[name] = value + return cls(**kwargs) def __post_init__(self): """check and unify fields""" diff --git a/fastdeploy/splitwise/splitwise_connector.py b/fastdeploy/splitwise/splitwise_connector.py index 9c13f26d8..6fbc40729 100644 --- a/fastdeploy/splitwise/splitwise_connector.py +++ b/fastdeploy/splitwise/splitwise_connector.py @@ -199,7 +199,7 @@ class SplitwiseConnector: f"{task.disaggregate_info['cache_info']['rdma']['ip']}:" + f"{task.disaggregate_info['cache_info']['rdma']['port']}" ) - self.logger.info(f"send splitwise tasks to port {addr} decode") + self.logger.info(f"send splitwise tasks to port {addr} decode, {task.request_id}") self.current_request_ids[task.request_id] = "init" decode_diagg = task.disaggregate_info["cache_info"] task.disaggregate_info["cache_info"] = self.cfg.disaggregate_info["cache_info"] @@ -271,6 +271,7 @@ class SplitwiseConnector: ) def check_decode_allocated(self, task): + self.logger.debug(f"start check decode allocated: {task.request_id}") start_time = time.time() if task.disaggregate_info is None: return True, "" @@ -280,7 +281,7 @@ class SplitwiseConnector: return True, "" while self.current_request_ids[task.request_id] == "init": time.sleep(0.001) - if time.time() - start_time > 30: + if time.time() - start_time > envs.FD_PREFILL_WAIT_DECODE_RESOURCE_SECONDS: del self.current_request_ids[task.request_id] return False, "timeout" msg = self.current_request_ids[task.request_id] @@ -363,6 +364,7 @@ class SplitwiseConnector: "rdma_ports": self.cfg.disaggregate_info["cache_info"]["rdma"]["rdma_port"], "transfer_protocol": "rdma", "dest_block_ids": dsg_info["block_tables"], + "decode_tp_size": self.cfg.parallel_config.tensor_parallel_size, } addr = f"{dsg_info['cache_info']['rdma']['ip']}:" + f"{dsg_info['cache_info']['rdma']['port']}" diff --git a/tests/e2e/test_ernie_03b_pd_router_v1_rdma.py b/tests/e2e/test_ernie_03b_pd_router_v1_rdma_tp1.py similarity index 100% rename from tests/e2e/test_ernie_03b_pd_router_v1_rdma.py rename to tests/e2e/test_ernie_03b_pd_router_v1_rdma_tp1.py diff --git a/tests/e2e/test_ernie_03b_pd_router_v1_rdma_tp2.py b/tests/e2e/test_ernie_03b_pd_router_v1_rdma_tp2.py new file mode 100644 index 000000000..e16456f77 --- /dev/null +++ b/tests/e2e/test_ernie_03b_pd_router_v1_rdma_tp2.py @@ -0,0 +1,427 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Test splitwise deployment: use local_scheduler + router, +# set ENABLE_V1_KVCACHE_SCHEDULER is 1, use rdma to transfer cache, +# the tp_size of prefill is 2 and the tp_size of decode is 1. + +import json +import os +import shutil +import signal +import subprocess +import sys +import time + +import pytest +import requests +from utils.serving_utils import ( + FD_API_PORT, + FD_CACHE_QUEUE_PORT, + FD_ENGINE_QUEUE_PORT, + FD_METRICS_PORT, + clean, + get_registered_number, +) + +# Read ports from environment variables; use default values if not set +FD_CONNECTOR_PORT = int(os.getenv("FD_CONNECTOR_PORT", 8433)) +FD_ROUTER_PORT = int(os.getenv("FD_ROUTER_PORT", 8533)) +FD_RDMA_PORT = int(os.getenv("FD_RDMA_PORT", 8623)) + +# List of ports to clean before and after tests +PORTS_TO_CLEAN = [ + FD_API_PORT, + FD_ENGINE_QUEUE_PORT, + FD_METRICS_PORT, + FD_CACHE_QUEUE_PORT, + FD_CONNECTOR_PORT, + FD_RDMA_PORT, + FD_RDMA_PORT + 1, + FD_API_PORT + 1, + FD_ENGINE_QUEUE_PORT + 1, + FD_METRICS_PORT + 1, + FD_CACHE_QUEUE_PORT + 1, + FD_CONNECTOR_PORT + 1, + FD_RDMA_PORT + 2, + FD_ROUTER_PORT, +] + + +@pytest.fixture(scope="session", autouse=True) +def setup_and_run_server(): + """ + Pytest fixture that runs once per test session: + - Cleans ports before tests + - Starts the API server as a subprocess + - Waits for server port to open (up to 30 seconds) + - Tears down server after all tests finish + """ + print("Pre-test port cleanup...") + clean(PORTS_TO_CLEAN) + + print("log dir clean ") + if os.path.exists("log_router") and os.path.isdir("log_router"): + shutil.rmtree("log_router") + if os.path.exists("log_prefill") and os.path.isdir("log_prefill"): + shutil.rmtree("log_prefill") + if os.path.exists("log_decode") and os.path.isdir("log_decode"): + shutil.rmtree("log_decode") + + base_path = os.getenv("MODEL_PATH") + if base_path: + model_path = os.path.join(base_path, "ERNIE-4.5-0.3B-Paddle") + else: + model_path = "baidu/ERNIE-4.5-0.3B-Paddle" + print(f"model_path: {model_path}") + + # get rdma nics + current_dir = os.path.dirname(os.path.abspath(__file__)) + shell_path = os.path.join(current_dir, "utils/get_rdma_nics.sh") + output = subprocess.check_output(["bash", shell_path, "gpu"], text=True) + _, rdma_nics = output.split("=") + print(f"shell_path: {shell_path}, rdma_nics: {rdma_nics}") + + # router + print("start router...") + env_router = os.environ.copy() + env_router["FD_LOG_DIR"] = "log_router" + router_log_path = "router.log" + + router_cmd = [ + sys.executable, + "-m", + "fastdeploy.router.launch", + "--port", + str(FD_ROUTER_PORT), + "--splitwise", + ] + + with open(router_log_path, "w") as logfile: + process_router = subprocess.Popen( + router_cmd, + stdout=logfile, + stderr=subprocess.STDOUT, + start_new_session=True, # Enables killing full group via os.killpg + env=env_router, + ) + + # prefill实例 + print("start prefill...") + env_prefill = os.environ.copy() + env_prefill["CUDA_VISIBLE_DEVICES"] = "0,1" + env_prefill["FD_LOG_DIR"] = "log_prefill" + env_prefill["KVCACHE_RDMA_NICS"] = rdma_nics + + prefill_log_path = "prefill.log" + prefill_cmd = [ + sys.executable, + "-m", + "fastdeploy.entrypoints.openai.api_server", + "--model", + model_path, + "--port", + str(FD_API_PORT), + "--engine-worker-queue-port", + str(FD_ENGINE_QUEUE_PORT), + "--metrics-port", + str(FD_METRICS_PORT), + "--cache-queue-port", + str(FD_CACHE_QUEUE_PORT), + "--max-model-len", + "8192", + "--tensor-parallel-size", + "2", + "--splitwise-role", + "prefill", + "--cache-transfer-protocol", + "rdma", + 
"--rdma-comm-ports", + f"{FD_RDMA_PORT},{FD_RDMA_PORT+1}", + "--pd-comm-port", + str(FD_CONNECTOR_PORT), + "--router", + f"0.0.0.0:{FD_ROUTER_PORT}", + ] + + # Start subprocess in new process group + with open(prefill_log_path, "w") as logfile: + process_prefill = subprocess.Popen( + prefill_cmd, + stdout=logfile, + stderr=subprocess.STDOUT, + start_new_session=True, # Enables killing full group via os.killpg + env=env_prefill, + ) + time.sleep(1) + + # decode实例 + print("start decode...") + env_decode = os.environ.copy() + env_decode["CUDA_VISIBLE_DEVICES"] = "1" + env_decode["FD_LOG_DIR"] = "log_decode" + env_decode["KVCACHE_RDMA_NICS"] = rdma_nics + + decode_log_path = "decode.log" + decode_cmd = [ + sys.executable, + "-m", + "fastdeploy.entrypoints.openai.api_server", + "--model", + model_path, + "--port", + str(FD_API_PORT + 1), + "--engine-worker-queue-port", + str(FD_ENGINE_QUEUE_PORT + 1), + "--metrics-port", + str(FD_METRICS_PORT + 1), + "--cache-queue-port", + str(FD_CACHE_QUEUE_PORT + 1), + "--max-model-len", + "8192", + "--splitwise-role", + "decode", + "--cache-transfer-protocol", + "rdma", + "--rdma-comm-ports", + str(FD_RDMA_PORT + 2), + "--pd-comm-port", + str(FD_CONNECTOR_PORT + 1), + "--router", + f"0.0.0.0:{FD_ROUTER_PORT}", + ] + + # Start subprocess in new process group + with open(decode_log_path, "w") as logfile: + process_decode = subprocess.Popen( + decode_cmd, + stdout=logfile, + stderr=subprocess.STDOUT, + start_new_session=True, # Enables killing full group via os.killpg + env=env_decode, + ) + + # Wait up to 300 seconds for API server to be ready + for _ in range(60): + registered_numbers = get_registered_number(f"0.0.0.0:{FD_ROUTER_PORT}") + if registered_numbers["prefill"] >= 1 and registered_numbers["decode"] >= 1: + print("Prefill and decode servers are both online") + break + time.sleep(5) + else: + print("[TIMEOUT] API server failed to start in 5 minutes. Cleaning up...") + try: + os.killpg(process_router.pid, signal.SIGTERM) + os.killpg(process_prefill.pid, signal.SIGTERM) + os.killpg(process_decode.pid, signal.SIGTERM) + clean(PORTS_TO_CLEAN) + except Exception as e: + print(f"Failed to kill process group: {e}") + raise RuntimeError(f"API server did not start on port {FD_API_PORT}") + + yield # Run tests + + print("\n===== Post-test server cleanup... =====") + try: + os.killpg(process_router.pid, signal.SIGTERM) + os.killpg(process_prefill.pid, signal.SIGTERM) + os.killpg(process_decode.pid, signal.SIGTERM) + clean(PORTS_TO_CLEAN) + print(f"Prefill server (pid={process_prefill.pid}) terminated") + print(f"Decode server (pid={process_decode.pid}) terminated") + except Exception as e: + print(f"Failed to terminate API server: {e}") + + +@pytest.fixture(scope="session") +def api_url(request): + """ + Returns the API endpoint URL for chat completions. + """ + return f"http://0.0.0.0:{FD_ROUTER_PORT}/v1/chat/completions" + + +@pytest.fixture(scope="session") +def metrics_url(request): + """ + Returns the metrics endpoint URL. + """ + return f"http://0.0.0.0:{FD_METRICS_PORT}/metrics" + + +@pytest.fixture +def headers(): + """ + Returns common HTTP request headers. 
+ """ + return {"Content-Type": "application/json"} + + +def test_metrics_config(metrics_url): + timeout = 600 + url = metrics_url.replace("metrics", "config-info") + res = requests.get(url, timeout=timeout) + assert res.status_code == 200 + + +def send_request(url, payload, timeout=600): + """ + 发送请求到指定的URL,并返回响应结果。 + """ + headers = { + "Content-Type": "application/json", + } + + try: + res = requests.post(url, headers=headers, json=payload, timeout=timeout) + print("🟢 接收响应中...\n") + return res + except requests.exceptions.Timeout: + print(f"❌ 请求超时(超过 {timeout} 秒)") + return None + except requests.exceptions.RequestException as e: + print(f"❌ 请求失败:{e}") + return None + + +def get_stream_chunks(response): + """解析流式返回,生成chunk List[dict]""" + chunks = [] + + if response.status_code == 200: + for line in response.iter_lines(decode_unicode=True): + if line: + if line.startswith("data: "): + line = line[len("data: ") :] + + if line.strip() == "[DONE]": + break + + try: + chunk = json.loads(line) + chunks.append(chunk) + except Exception as e: + print(f"解析失败: {e}, 行内容: {line}") + else: + print(f"请求失败,状态码: {response.status_code}") + print("返回内容:", response.text) + + return chunks + + +def test_chat_usage_stream(api_url): + """测试流式chat usage""" + payload = { + "model": "default", + "temperature": 0, + "top_p": 0, + "seed": 33, + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "牛顿的三大运动定律是什么?"}, + ], + "max_tokens": 50, + "stream": True, + "stream_options": {"include_usage": True, "continuous_usage_stats": True}, + "metadata": {"min_tokens": 10}, + } + + response = send_request(url=api_url, payload=payload) + chunks = get_stream_chunks(response) + result = "".join([x["choices"][0]["delta"]["content"] for x in chunks[:-1]]) + print("Decode Response:", result) + assert result != "", "结果为空" + usage = chunks[-1]["usage"] + total_tokens = usage["completion_tokens"] + usage["prompt_tokens"] + assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens" + assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens" + assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens" + + +def test_chat_usage_non_stream(api_url): + """测试非流式chat usage""" + payload = { + "model": "default", + "temperature": 0, + "top_p": 0, + "seed": 33, + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "牛顿的三大运动定律是什么?"}, + ], + "max_tokens": 50, + "stream": False, + "metadata": {"min_tokens": 10}, + } + + response = send_request(url=api_url, payload=payload).json() + usage = response["usage"] + result = response["choices"][0]["message"]["content"] + assert result != "", "结果为空" + total_tokens = usage["completion_tokens"] + usage["prompt_tokens"] + assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens" + assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens" + assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens" + + +def test_non_chat_usage_stream(api_url): + """测试流式非chat usage""" + payload = { + "model": "default", + "temperature": 0, + "top_p": 0, + "seed": 33, + "prompt": "牛顿的三大运动定律是什么?", + "max_tokens": 50, + "stream": True, + "stream_options": {"include_usage": True, "continuous_usage_stats": True}, + "metadata": {"min_tokens": 10}, + } + api_url = 
api_url.replace("chat/completions", "completions") + + response = send_request(url=api_url, payload=payload) + chunks = get_stream_chunks(response) + result = "".join([x["choices"][0]["text"] for x in chunks[:-1]]) + print("Decode Response:", result) + assert result != "", "结果为空" + usage = chunks[-1]["usage"] + total_tokens = usage["completion_tokens"] + usage["prompt_tokens"] + assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens" + assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens" + assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens" + + +def test_non_chat_usage_non_stream(api_url): + """测试非流式非chat usage""" + payload = { + "model": "default", + "temperature": 0, + "top_p": 0, + "seed": 33, + "prompt": "牛顿的三大运动定律是什么?", + "max_tokens": 50, + "stream": False, + "metadata": {"min_tokens": 10}, + } + api_url = api_url.replace("chat/completions", "completions") + + response = send_request(url=api_url, payload=payload).json() + usage = response["usage"] + result = response["choices"][0]["text"] + print("Decode Response:", result) + assert result != "", "结果为空" + total_tokens = usage["completion_tokens"] + usage["prompt_tokens"] + assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens" + assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens" + assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens"
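For reference, the address arithmetic introduced in RDMACommunicator::write_cache above can be
summarized in a few lines of Python. This is a minimal illustrative sketch of the same rule, not
the code shipped in kvcache_rdma.cpp; the function name and the numbers in the checks are made up.
When prefill and decode use the same tp_size, each rank writes block b at b * block_size_byte.
When they differ (decode tp_size 1), one decode block holds the shards of all prefill ranks, so
the stride grows to prefill_tp_size * block_size_byte and rank i writes its shard at offset
i * block_size_byte inside the destination block.

def remote_block_addr(remote_base, block_id, block_size_byte,
                      prefill_tp_size, prefill_tp_idx, decode_tp_size):
    """Remote write address of one KV block (illustrative mirror of write_cache's offset rule)."""
    same = prefill_tp_size == decode_tp_size
    # Destination stride: one decode block spans all prefill shards when tp sizes differ.
    stride = block_size_byte if same else block_size_byte * prefill_tp_size
    # Each prefill rank lands at its own offset inside the decode block.
    offset = 0 if same else block_size_byte * prefill_tp_idx
    return remote_base + block_id * stride + offset

# prefill tp_size=2 writing into a decode instance with tp_size=1,
# 1024-byte shard per rank per block, remote layer base at 0x1000:
assert remote_block_addr(0x1000, 3, 1024, 2, 0, 1) == 0x1000 + 3 * 2048         # rank 0
assert remote_block_addr(0x1000, 3, 1024, 2, 1, 1) == 0x1000 + 3 * 2048 + 1024  # rank 1

This layout is also why the router only pairs instances whose tp_size is equal or whose decode
tp_size is 1, and why the prefill-side cache messager accepts rdma_ports of length 1 or nranks.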