[PD Disaggregation] support different tp_size for prefill and decode (#5296)
* up
* up
* up
* fix
@@ -55,13 +55,13 @@ def parse_args():
         default="mixed",
         help="splitwise role, can be decode, prefill or mixed",
     )
-    parser.add_argument("--rank", type=int, default=0, help="current rank")
+    parser.add_argument("--rank", type=int, default=0, help="local tp rank id")
     parser.add_argument("--device_id", type=int, default=0, help="device id")
     parser.add_argument("--num_layers", type=int, default=1, help="model num layers")
     parser.add_argument("--key_cache_shape", type=str, default="", help="key cache shape")
     parser.add_argument("--value_cache_shape", type=str, default="", help="value cache shape")
     parser.add_argument("--rdma_port", type=str, default="", help="rmda port")
-    parser.add_argument("--mp_num", type=int, default=1, help="number of model parallel")
+    parser.add_argument("--mp_num", type=int, default=1, help="number of model parallel, i.e. tp_size, tp_num")
     parser.add_argument("--engine_pid", type=str, default=None, help="engine pid")
     parser.add_argument(
         "--protocol",
@@ -208,6 +208,8 @@ class CacheMessager:
                     max_block_num,
                     block_bytes,
                     rdma_port,
+                    nranks,
+                    rank,
                 )

         self.gpu_id = gpu_id
@@ -507,6 +509,8 @@ class CacheMessagerV1:
                     max_block_num,
                     block_bytes,
                     rdma_port,
+                    nranks,
+                    rank,
                 )

         self.gpu_id = gpu_id
@@ -595,6 +599,7 @@ class CacheMessagerV1:
                 block_id_end = prefilled_token_num // self.block_size  # [block_id_start, block_id_end)
                 block_start_end_list.append((block_id_start, block_id_end))
+                current_prefilled_token_num_list.append(prefilled_token_num)

             while True:  # from layer0 to last layer
                 sended_layer_idx = self.idx_cache_task_dict[batch_engine_signals[0][0]]["sended_layer_id"]
                 start_layer_idx = sended_layer_idx + 1
@@ -633,13 +638,27 @@ class CacheMessagerV1:
                     current_transfer_protocol = task["transfer_protocol"]
                     if task["transfer_protocol"] == "rdma":
                         target_ip = task["ip"]
-                        target_id = int(task["rdma_ports"][self.rank])
+                        # Default decode_tp_size to prefill tp_size (self.nranks) if not specified
+                        decode_tp_size = task.get("decode_tp_size", self.nranks)
+                        if len(task["rdma_ports"]) == self.nranks:
+                            target_id = int(task["rdma_ports"][self.rank])
+                        elif len(task["rdma_ports"]) == 1:
+                            target_id = task["rdma_ports"][0]
+                        else:
+                            task["status"] = "the tp_size of prefill and decode is mismatch"
+                            continue
+
                         if "error" in task["status"]:
                             continue

                         # TODO: use is connected to check if the connection is still alive
-                        logger.debug(f"rdma, start connect decode, {target_ip}:{target_id}")
-                        status = self.messager[current_transfer_protocol].connect(target_ip, target_id)
+                        logger.debug(
+                            f"rdma, start connect decode, {target_ip}:{target_id}, "
+                            f"prefill_tp_size:{self.nranks}, decode_tp_size:{decode_tp_size}"
+                        )
+                        status = self.messager[current_transfer_protocol].connect(
+                            target_ip, target_id, decode_tp_size
+                        )
                         if status:
                             logger.info(f"connect to {target_ip}:{target_id} success")
                         else:
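The dispatch rule this hunk introduces is compact enough to state on its own: the decode side advertises either one rdma port per prefill rank (equal tp_size) or exactly one port (decode tp_size == 1); any other length is rejected. A pure-Python sketch of that rule, with a hypothetical helper name, not part of the commit:

    def pick_target_port(rdma_ports, prefill_rank, prefill_nranks):
        # One decode port per prefill rank: symmetric tp_size, each prefill
        # rank connects to its own counterpart.
        if len(rdma_ports) == prefill_nranks:
            return int(rdma_ports[prefill_rank])
        # A single decode port: decode runs with tp_size == 1, so every
        # prefill rank connects to the same endpoint.
        if len(rdma_ports) == 1:
            return int(rdma_ports[0])
        raise ValueError("the tp_size of prefill and decode is mismatch")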
@@ -762,12 +781,22 @@ class CacheMessagerV1:
                     self.engine_worker_queue.connect_task_barrier.wait()
                     logger.info(f"_handle_connect_task recv task: {task}")
                     task_id = task["task_id"]
-                    ip, rdma_port = task["ip"], task["rdma_ports"][self.rank]
-                    status = self.messager["rdma"].connect(ip, rdma_port)
-                    if not status:
+                    ip = task["ip"]
+                    # Default decode_tp_size to self.nranks (number of ranks) if not specified in the task.
+                    decode_tp_size = task.get("decode_tp_size", self.nranks)
+                    rdma_ports = task["rdma_ports"]
+                    rdma_ports_len = len(rdma_ports)
+                    if not (rdma_ports_len == 1 or rdma_ports_len == self.nranks):
+                        # TODO: support other cases
+                        logger.error(f"rdma_ports length should be 1 or equal to mp_num, but got {rdma_ports_len}")
                         response = {"task_id": task_id, "success": False}
                     else:
-                        response = {"task_id": task_id, "success": True}
+                        port = rdma_ports[0] if rdma_ports_len == 1 else rdma_ports[self.rank]
+                        status = self.messager["rdma"].connect(ip, port, decode_tp_size)
+                        if not status:
+                            response = {"task_id": task_id, "success": False}
+                        else:
+                            response = {"task_id": task_id, "success": True}
                     self.engine_worker_queue.connect_task_response_barrier.wait()
                     self.engine_worker_queue.put_connect_rdma_task_response(response)
                 except Exception as e:
@@ -142,6 +142,7 @@ struct Connection {
   int wc_target_count;

   // Configuration
+  int decode_tp_size;
   int layer_number;
   int block_number;
   int block_byte_size;
@@ -24,11 +24,15 @@ class RDMACommunicator {
                    std::vector<int64_t> local_key_cache,
                    std::vector<int64_t> local_value_cache,
                    int block_number,
-                   int block_bytes);
+                   int block_bytes,
+                   int prefill_tp_size,
+                   int prefill_tp_idx);
   ~RDMACommunicator();

   // Connection management
-  int connect(const std::string& dst_ip, const std::string& dst_port);
+  int connect(const std::string& dst_ip,
+              const std::string& dst_port,
+              int dest_tp_size);
   bool is_connected(const std::string& dst_ip, const std::string& dst_port);

   // Core functionality
@@ -120,6 +124,8 @@ class RDMACommunicator {
   int block_number;     // Number of blocks
   int block_size_byte;  // Size of each block in bytes
   int layer_number;     // Number of layers
+  int prefill_tp_size;  // tensor parallelism size for prefill
+  int prefill_tp_idx;   // tensor parallelism index for prefill

   std::vector<std::vector<void*>>
       local_cache_key_ptr_per_layer;  // Per-layer key pointers
@@ -41,7 +41,7 @@
 * @param local_key_cache Vector of local key cache pointers
 * @param local_value_cache Vector of local value cache pointers
 * @param block_number Number of blocks in cache
- * @param block_bytes Size of each block in bytes
+ * @param block_bytes Bytes of each block in each tp rank
 *
 * @throws std::runtime_error If initialization fails
 */
@@ -51,7 +51,9 @@ RDMACommunicator::RDMACommunicator(std::string& role,
                                    std::vector<int64_t> local_key_cache,
                                    std::vector<int64_t> local_value_cache,
                                    int block_number,
-                                   int block_bytes)
+                                   int block_bytes,
+                                   int prefill_tp_size,
+                                   int prefill_tp_idx)
     : splitwise_role(role),
       gpu_idx(gpu_idx),
       port(port),
@@ -59,6 +61,8 @@ RDMACommunicator::RDMACommunicator(std::string& role,
       local_cache_value_ptr_layer_head_(std::move(local_value_cache)),
       block_number(block_number),
       block_size_byte(block_bytes),
+      prefill_tp_size(prefill_tp_size),
+      prefill_tp_idx(prefill_tp_idx),
       RDMACommunicator_status(0),
       rdma_event_channel_epoll_fd(-1) {
   try {
@@ -480,11 +484,14 @@ std::string RDMACommunicator::fetch_local_ip() {
 *
 * @param dst_ip Destination IP address
 * @param dst_port Destination port
+ * @param dest_tp_size Default 0: assumes dest has same tp_size as source;
+ *                     otherwise specifies decode tp_size
 * @return ConnStatus::kConnected ConnStatus::kError;
 */

 int RDMACommunicator::connect(const std::string& dst_ip,
-                              const std::string& dst_port) {
+                              const std::string& dst_port,
+                              int dest_tp_size = 0) {
   std::string url = dst_ip + ":" + dst_port;

   // Initialize IB devices if not already done
@@ -515,6 +522,10 @@ int RDMACommunicator::connect(const std::string& dst_ip,
   ctx->conn.layer_number = layer_number;
   ctx->conn.block_number = block_number;
   ctx->conn.block_byte_size = block_size_byte;
+  if (dest_tp_size > 0)
+    ctx->conn.decode_tp_size = dest_tp_size;
+  else
+    ctx->conn.decode_tp_size = prefill_tp_size;

   // Get port information for the connection
   if (get_port_info(ctx->context, ib_dev->port, &ctx->portinfo)) {
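The `dest_tp_size = 0` default above acts as a sentinel: zero means "assume the destination has the same tp_size as this side". The same resolution in Python, as a sketch with illustrative asserts, not part of the commit:

    def resolve_decode_tp_size(dest_tp_size: int, prefill_tp_size: int) -> int:
        # Mirrors the C++ branch: an explicit positive value wins, otherwise
        # fall back to the local (prefill) tp_size.
        return dest_tp_size if dest_tp_size > 0 else prefill_tp_size

    assert resolve_decode_tp_size(0, 2) == 2  # unspecified: assume symmetric tp
    assert resolve_decode_tp_size(1, 2) == 1  # prefill tp=2 writing to decode tp=1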
@@ -537,9 +548,6 @@ int RDMACommunicator::connect(const std::string& dst_ip,
     ERR("Couldn't getexchange port infodestinations");
     return static_cast<int>(ConnStatus::kError);
   } else {
-    std::lock_guard<std::mutex> lock(mutex_);
-    ctx->conn.connected = 1;
-    conn_map[url] = ctx;
     client_exchange_mr(ctx);
   }

@@ -589,6 +597,10 @@ int RDMACommunicator::connect(const std::string& dst_ip,
     }
   }

+  std::lock_guard<std::mutex> lock(mutex_);
+  ctx->conn.connected = 1;
+  conn_map[url] = ctx;
+
   WARN("connect end ....");
   return static_cast<int>(ConnStatus::kConnected);
 }
@@ -649,6 +661,7 @@ int RDMACommunicator::client_listener() {

 bool RDMACommunicator::is_connected(const std::string& dst_ip,
                                     const std::string& dst_port) {
+  std::lock_guard<std::mutex> lock(mutex_);
   std::string url = dst_ip + ":" + dst_port;
   return conn_map.find(url) != conn_map.end();
 }
@@ -889,17 +902,25 @@ int RDMACommunicator::write_cache(const std::string& ip,
   uint32_t cache_value_rkey =
       ctx->conn.write_cache_value_remote_rkey_list[layer_idx];
   uint32_t crc_cache_key_rkey, crc_cache_value_rkey;
+  bool pd_tp_size_is_same = prefill_tp_size == ctx->conn.decode_tp_size;
+  uint64_t offset_in_block =
+      pd_tp_size_is_same ? 0 : block_size_byte * prefill_tp_idx;
+  uint64_t total_block_size_byte =
+      pd_tp_size_is_same ? block_size_byte : block_size_byte * prefill_tp_size;
+
   for (size_t block_index = 0; block_index < block_num; ++block_index) {
     char* char_ptr = static_cast<char*>(
         ctx->conn.write_cache_key_remote_ptr_list[layer_idx]);
-    cache_key_remote_addr[block_index] =
-        (uint64_t(char_ptr + remote_block_ids[block_index] * block_size_byte));
+    cache_key_remote_addr[block_index] = (uint64_t(
+        char_ptr + remote_block_ids[block_index] * total_block_size_byte +
+        offset_in_block));
     char_ptr = static_cast<char*>(
         ctx->conn.write_cache_value_remote_ptr_list[layer_idx]);
-    cache_value_remote_addr[block_index] =
-        (uint64_t(char_ptr + remote_block_ids[block_index] * block_size_byte));
+    cache_value_remote_addr[block_index] = (uint64_t(
+        char_ptr + remote_block_ids[block_index] * total_block_size_byte +
+        offset_in_block));
   }

   ctx->conn.wc_target_count = 0;
   for (int i = 0; i < 2; ++i) {
     bool is_key = (i == 0);
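This address arithmetic is the core of the change. When the tp sizes differ, a decode block is laid out as the concatenation of all prefill shards, so prefill rank i writes its shard of block b at b * (block_bytes * prefill_tp_size) + block_bytes * i. A worked example in Python with illustrative numbers, not taken from the commit:

    prefill_tp_size, prefill_tp_idx = 2, 1  # the second of two prefill ranks
    decode_tp_size = 1                      # decode holds the full, unsharded block
    block_size_byte = 4096                  # bytes of one block on one prefill rank

    pd_tp_size_is_same = prefill_tp_size == decode_tp_size
    offset_in_block = 0 if pd_tp_size_is_same else block_size_byte * prefill_tp_idx
    total_block_size_byte = (
        block_size_byte if pd_tp_size_is_same else block_size_byte * prefill_tp_size
    )

    remote_block_id = 7
    addr = remote_block_id * total_block_size_byte + offset_in_block
    assert addr == 7 * 8192 + 4096  # rank 1 writes the upper half of block 7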
@@ -14,10 +14,39 @@ PYBIND11_MODULE(rdma_comm, m) {
                     std::vector<int64_t>,
                     std::vector<int64_t>,
                     int,
-                    int>())
-      .def("connect", &RDMACommunicator::connect)
-      .def("is_connected", &RDMACommunicator::is_connected)
-      .def("write_cache", &RDMACommunicator::write_cache);
+                    int,
+                    int,
+                    int>(),
+           py::arg("splitwise_role"),
+           py::arg("gpu_idx"),
+           py::arg("port"),
+           py::arg("key_cache_ptrs"),
+           py::arg("value_cache_ptrs"),
+           py::arg("block_number"),
+           py::arg("block_bytes"),
+           py::arg("prefill_tp_size") = 1,
+           py::arg("prefill_tp_idx") = 0)
+      .def("connect",
+           &RDMACommunicator::connect,
+           py::arg("dst_ip"),
+           py::arg("dst_port"),
+           py::arg("dst_tp_size") =
+               0,  // Default 0: assumes dest has same tp_size as source;
+                   // otherwise specifies decode tp_size
+           py::call_guard<py::gil_scoped_release>())
+      .def("is_connected",
+           &RDMACommunicator::is_connected,
+           py::arg("dst_ip"),
+           py::arg("dst_port"),
+           py::call_guard<py::gil_scoped_release>())
+      .def("write_cache",
+           &RDMACommunicator::write_cache,
+           py::arg("dst_ip"),
+           py::arg("dst_port"),
+           py::arg("local_block_ids"),
+           py::arg("remote_block_ids"),
+           py::arg("layer_idx"),
+           py::call_guard<py::gil_scoped_release>());

 #ifdef VERSION_INFO
   m.attr("__version__") = VERSION_INFO;
@@ -34,6 +34,8 @@ class RDMACommManager:
         max_block_num,
         block_bytes,
         rdma_port,
+        prefill_tp_size,
+        prefill_tp_idx,
     ):
         try:
             import rdma_comm
@@ -51,12 +53,16 @@ class RDMACommManager:
             cache_v_ptr_list,
             max_block_num,
             block_bytes,
+            prefill_tp_size,
+            prefill_tp_idx,
         )
         self.splitwise_role = splitwise_role
         self.connected_rdma = set()
-        logger.info(f"init rdma messager {gpu_id} {rdma_port}")
+        logger.info(
+            f"init rdma messager {gpu_id} {rdma_port}, prefill_tp_size: {prefill_tp_size}, prefill_tp_idx: {prefill_tp_idx}"
+        )

-    def connect(self, ip, port):
+    def connect(self, ip, port, tp_size):
         """
         Connect to remote gpu and write cache.
         """
@@ -65,7 +71,7 @@ class RDMACommManager:
         if ret:
             return True

-        ret = self.messager.connect(ip, str(port))
+        ret = self.messager.connect(ip, str(port), tp_size)
         logger.info(f"connect to remote rdma address {ip}:{port} status is {ret}")
         return ret == 0

@@ -1888,6 +1888,7 @@ class FDConfig:
         logger.info(f"disaggregate_info: {self.disaggregate_info}")

         if self.router_config:
+            # the information for registering this server to router
             self.register_info = {
                 "role": self.scheduler_config.splitwise_role,
                 "host_ip": self.host_ip,
@@ -1897,6 +1898,7 @@ class FDConfig:
                 "engine_worker_queue_port": engine_worker_queue_port,
                 "device_ids": self.local_device_ids,
                 "transfer_protocol": self.cache_config.cache_transfer_protocol.split(","),
+                "tp_size": self.parallel_config.tensor_parallel_size,
             }
             logger.info(f"register_info: {self.register_info}")
@@ -150,6 +150,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "FD_ENABLE_PDL": lambda: int(os.getenv("FD_ENABLE_PDL", "1")),
     # "Number of tokens in the group for Mixture of Experts (MoE) computation processing on HPU"
     "FD_HPU_CHUNK_SIZE": lambda: int(os.getenv("FD_HPU_CHUNK_SIZE", "64")),
+    "FD_PREFILL_WAIT_DECODE_RESOURCE_SECONDS": lambda: int(os.getenv("FD_PREFILL_WAIT_DECODE_RESOURCE_SECONDS", "30")),
 }
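This knob replaces the hard-coded 30-second wait in `SplitwiseConnector.check_decode_allocated` (see the hunk further down). Since the table stores a lambda over `os.getenv`, the value appears to be taken from the process environment when accessed; a sketch of raising the wait to two minutes, assuming the variable is set before it is first read:

    import os

    # Default is "30" (seconds); prefill will wait this long for decode resources.
    os.environ["FD_PREFILL_WAIT_DECODE_RESOURCE_SECONDS"] = "120"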
@@ -95,7 +95,7 @@ class Router:
     async def register_instance(self, instance_info_dict: dict):
         """Register an instance asynchronously"""
         try:
-            inst_info = InstanceInfo(**instance_info_dict)
+            inst_info = InstanceInfo.from_dict(instance_info_dict)
         except Exception as e:
             logger.error(f"register instance failed: {e}")
             raise
@@ -173,11 +173,17 @@ class Router:
         logger.debug(f"Received request: {request_data}")
         prefill_server, decode_server = await self.select_pd()

+        if prefill_server.tp_size != decode_server.tp_size and decode_server.tp_size != 1:
+            raise HTTPException(
+                status_code=400,
+                detail="The tp_size of prefill and decode should be equal or the tp_size of decode is 1",
+            )
+
         # TODO: unify the disaggregate_info in server and remove redundancy params
         is_same_node = prefill_server.host_ip == decode_server.host_ip
-        use_ipc = (
-            is_same_node and "ipc" in prefill_server.transfer_protocol and "ipc" in decode_server.transfer_protocol
-        )
+        is_support_ipc = "ipc" in prefill_server.transfer_protocol and "ipc" in decode_server.transfer_protocol
+        is_same_tp_size = prefill_server.tp_size == decode_server.tp_size
+        use_ipc = is_same_node and is_support_ipc and is_same_tp_size

         cache_info = {}
         if use_ipc:
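The admission rule the router now enforces reduces to a two-clause predicate; a sketch with a hypothetical helper name, not part of the commit:

    def pd_pair_is_compatible(prefill_tp_size: int, decode_tp_size: int) -> bool:
        # Either both sides shard the KV cache identically, or decode is
        # unsharded (tp=1) and receives the concatenation of all prefill shards.
        return prefill_tp_size == decode_tp_size or decode_tp_size == 1

    assert pd_pair_is_compatible(2, 2)      # symmetric pair
    assert pd_pair_is_compatible(2, 1)      # the new case this PR enables
    assert not pd_pair_is_compatible(2, 4)  # rejected with HTTP 400

Note that ipc transfer additionally requires equal tp_size (`is_same_tp_size`), so mismatched pairs always go through rdma.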
@@ -15,9 +15,9 @@
 """

 import asyncio
-from dataclasses import asdict, dataclass, field
+from dataclasses import MISSING, asdict, dataclass, field, fields
 from enum import Enum
-from typing import List, Union
+from typing import Any, List, Union

 import aiohttp
 import requests
@@ -39,6 +39,24 @@ class InstanceInfo:
     transfer_protocol: List[str] = field(default_factory=list)
     rdma_ports: Union[List[str], List[int]] = field(default_factory=list)
     device_ids: Union[List[str], List[int]] = field(default_factory=list)
+    tp_size: int = 1
+
+    @classmethod
+    def from_dict(cls, info_dict: dict[str, Any]) -> "InstanceInfo":
+        """Create instance from dict arguments"""
+        kwargs = {}
+        for field_def in fields(cls):
+            name = field_def.name
+            if name in info_dict:
+                value = info_dict[name]
+            else:
+                # handle default and default_factory
+                if field_def.default is not MISSING:
+                    value = field_def.default
+                else:
+                    value = field_def.default_factory()
+            kwargs[name] = value
+        return cls(**kwargs)

     def __post_init__(self):
         """check and unify fields"""
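`from_dict` makes registration tolerant of version skew: unlike `InstanceInfo(**info)`, it drops unknown keys and fills defaults for missing ones. A usage sketch, assuming (as for the fields visible in this hunk) that every field declares a default or default_factory:

    info = {"tp_size": 2, "rdma_ports": [8623, 8624], "some_future_field": 1}
    inst = InstanceInfo.from_dict(info)  # "some_future_field" is ignored
    assert inst.tp_size == 2
    assert inst.device_ids == []         # absent key: default_factory applies
    # InstanceInfo(**info) would raise TypeError on the unknown key instead.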
@@ -199,7 +199,7 @@ class SplitwiseConnector:
                 f"{task.disaggregate_info['cache_info']['rdma']['ip']}:"
                 + f"{task.disaggregate_info['cache_info']['rdma']['port']}"
             )
-            self.logger.info(f"send splitwise tasks to port {addr} decode")
+            self.logger.info(f"send splitwise tasks to port {addr} decode, {task.request_id}")
             self.current_request_ids[task.request_id] = "init"
             decode_diagg = task.disaggregate_info["cache_info"]
             task.disaggregate_info["cache_info"] = self.cfg.disaggregate_info["cache_info"]
@@ -271,6 +271,7 @@ class SplitwiseConnector:
         )

     def check_decode_allocated(self, task):
+        self.logger.debug(f"start check decode allocated: {task.request_id}")
         start_time = time.time()
         if task.disaggregate_info is None:
             return True, ""
@@ -280,7 +281,7 @@ class SplitwiseConnector:
             return True, ""
         while self.current_request_ids[task.request_id] == "init":
             time.sleep(0.001)
-            if time.time() - start_time > 30:
+            if time.time() - start_time > envs.FD_PREFILL_WAIT_DECODE_RESOURCE_SECONDS:
                 del self.current_request_ids[task.request_id]
                 return False, "timeout"
         msg = self.current_request_ids[task.request_id]
@@ -363,6 +364,7 @@ class SplitwiseConnector:
             "rdma_ports": self.cfg.disaggregate_info["cache_info"]["rdma"]["rdma_port"],
             "transfer_protocol": "rdma",
             "dest_block_ids": dsg_info["block_tables"],
+            "decode_tp_size": self.cfg.parallel_config.tensor_parallel_size,
         }

         addr = f"{dsg_info['cache_info']['rdma']['ip']}:" + f"{dsg_info['cache_info']['rdma']['port']}"
tests/e2e/test_ernie_03b_pd_router_v1_rdma_tp2.py (new file, 427 lines)
@@ -0,0 +1,427 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Test splitwise deployment: use local_scheduler + router,
# set ENABLE_V1_KVCACHE_SCHEDULER to 1, use rdma to transfer cache,
# the tp_size of prefill is 2 and the tp_size of decode is 1.

import json
import os
import shutil
import signal
import subprocess
import sys
import time

import pytest
import requests
from utils.serving_utils import (
    FD_API_PORT,
    FD_CACHE_QUEUE_PORT,
    FD_ENGINE_QUEUE_PORT,
    FD_METRICS_PORT,
    clean,
    get_registered_number,
)

# Read ports from environment variables; use default values if not set
FD_CONNECTOR_PORT = int(os.getenv("FD_CONNECTOR_PORT", 8433))
FD_ROUTER_PORT = int(os.getenv("FD_ROUTER_PORT", 8533))
FD_RDMA_PORT = int(os.getenv("FD_RDMA_PORT", 8623))

# List of ports to clean before and after tests
PORTS_TO_CLEAN = [
    FD_API_PORT,
    FD_ENGINE_QUEUE_PORT,
    FD_METRICS_PORT,
    FD_CACHE_QUEUE_PORT,
    FD_CONNECTOR_PORT,
    FD_RDMA_PORT,
    FD_RDMA_PORT + 1,
    FD_API_PORT + 1,
    FD_ENGINE_QUEUE_PORT + 1,
    FD_METRICS_PORT + 1,
    FD_CACHE_QUEUE_PORT + 1,
    FD_CONNECTOR_PORT + 1,
    FD_RDMA_PORT + 2,
    FD_ROUTER_PORT,
]


@pytest.fixture(scope="session", autouse=True)
def setup_and_run_server():
    """
    Pytest fixture that runs once per test session:
    - Cleans ports before tests
    - Starts the API server as a subprocess
    - Waits for server port to open (up to 30 seconds)
    - Tears down server after all tests finish
    """
    print("Pre-test port cleanup...")
    clean(PORTS_TO_CLEAN)

    print("log dir clean")
    if os.path.exists("log_router") and os.path.isdir("log_router"):
        shutil.rmtree("log_router")
    if os.path.exists("log_prefill") and os.path.isdir("log_prefill"):
        shutil.rmtree("log_prefill")
    if os.path.exists("log_decode") and os.path.isdir("log_decode"):
        shutil.rmtree("log_decode")

    base_path = os.getenv("MODEL_PATH")
    if base_path:
        model_path = os.path.join(base_path, "ERNIE-4.5-0.3B-Paddle")
    else:
        model_path = "baidu/ERNIE-4.5-0.3B-Paddle"
    print(f"model_path: {model_path}")

    # get rdma nics
    current_dir = os.path.dirname(os.path.abspath(__file__))
    shell_path = os.path.join(current_dir, "utils/get_rdma_nics.sh")
    output = subprocess.check_output(["bash", shell_path, "gpu"], text=True)
    _, rdma_nics = output.split("=")
    print(f"shell_path: {shell_path}, rdma_nics: {rdma_nics}")

    # router
    print("start router...")
    env_router = os.environ.copy()
    env_router["FD_LOG_DIR"] = "log_router"
    router_log_path = "router.log"

    router_cmd = [
        sys.executable,
        "-m",
        "fastdeploy.router.launch",
        "--port",
        str(FD_ROUTER_PORT),
        "--splitwise",
    ]

    with open(router_log_path, "w") as logfile:
        process_router = subprocess.Popen(
            router_cmd,
            stdout=logfile,
            stderr=subprocess.STDOUT,
            start_new_session=True,  # Enables killing full group via os.killpg
            env=env_router,
        )

    # prefill instance
    print("start prefill...")
    env_prefill = os.environ.copy()
    env_prefill["CUDA_VISIBLE_DEVICES"] = "0,1"
    env_prefill["FD_LOG_DIR"] = "log_prefill"
    env_prefill["KVCACHE_RDMA_NICS"] = rdma_nics

    prefill_log_path = "prefill.log"
    prefill_cmd = [
        sys.executable,
        "-m",
        "fastdeploy.entrypoints.openai.api_server",
        "--model",
        model_path,
        "--port",
        str(FD_API_PORT),
        "--engine-worker-queue-port",
        str(FD_ENGINE_QUEUE_PORT),
        "--metrics-port",
        str(FD_METRICS_PORT),
        "--cache-queue-port",
        str(FD_CACHE_QUEUE_PORT),
        "--max-model-len",
        "8192",
        "--tensor-parallel-size",
        "2",
        "--splitwise-role",
        "prefill",
        "--cache-transfer-protocol",
        "rdma",
        "--rdma-comm-ports",
        f"{FD_RDMA_PORT},{FD_RDMA_PORT+1}",
        "--pd-comm-port",
        str(FD_CONNECTOR_PORT),
        "--router",
        f"0.0.0.0:{FD_ROUTER_PORT}",
    ]

    # Start subprocess in new process group
    with open(prefill_log_path, "w") as logfile:
        process_prefill = subprocess.Popen(
            prefill_cmd,
            stdout=logfile,
            stderr=subprocess.STDOUT,
            start_new_session=True,  # Enables killing full group via os.killpg
            env=env_prefill,
        )
    time.sleep(1)

    # decode instance
    print("start decode...")
    env_decode = os.environ.copy()
    env_decode["CUDA_VISIBLE_DEVICES"] = "1"
    env_decode["FD_LOG_DIR"] = "log_decode"
    env_decode["KVCACHE_RDMA_NICS"] = rdma_nics

    decode_log_path = "decode.log"
    decode_cmd = [
        sys.executable,
        "-m",
        "fastdeploy.entrypoints.openai.api_server",
        "--model",
        model_path,
        "--port",
        str(FD_API_PORT + 1),
        "--engine-worker-queue-port",
        str(FD_ENGINE_QUEUE_PORT + 1),
        "--metrics-port",
        str(FD_METRICS_PORT + 1),
        "--cache-queue-port",
        str(FD_CACHE_QUEUE_PORT + 1),
        "--max-model-len",
        "8192",
        "--splitwise-role",
        "decode",
        "--cache-transfer-protocol",
        "rdma",
        "--rdma-comm-ports",
        str(FD_RDMA_PORT + 2),
        "--pd-comm-port",
        str(FD_CONNECTOR_PORT + 1),
        "--router",
        f"0.0.0.0:{FD_ROUTER_PORT}",
    ]

    # Start subprocess in new process group
    with open(decode_log_path, "w") as logfile:
        process_decode = subprocess.Popen(
            decode_cmd,
            stdout=logfile,
            stderr=subprocess.STDOUT,
            start_new_session=True,  # Enables killing full group via os.killpg
            env=env_decode,
        )

    # Wait up to 300 seconds for API server to be ready
    for _ in range(60):
        registered_numbers = get_registered_number(f"0.0.0.0:{FD_ROUTER_PORT}")
        if registered_numbers["prefill"] >= 1 and registered_numbers["decode"] >= 1:
            print("Prefill and decode servers are both online")
            break
        time.sleep(5)
    else:
        print("[TIMEOUT] API server failed to start in 5 minutes. Cleaning up...")
        try:
            os.killpg(process_router.pid, signal.SIGTERM)
            os.killpg(process_prefill.pid, signal.SIGTERM)
            os.killpg(process_decode.pid, signal.SIGTERM)
            clean(PORTS_TO_CLEAN)
        except Exception as e:
            print(f"Failed to kill process group: {e}")
        raise RuntimeError(f"API server did not start on port {FD_API_PORT}")

    yield  # Run tests

    print("\n===== Post-test server cleanup... =====")
    try:
        os.killpg(process_router.pid, signal.SIGTERM)
        os.killpg(process_prefill.pid, signal.SIGTERM)
        os.killpg(process_decode.pid, signal.SIGTERM)
        clean(PORTS_TO_CLEAN)
        print(f"Prefill server (pid={process_prefill.pid}) terminated")
        print(f"Decode server (pid={process_decode.pid}) terminated")
    except Exception as e:
        print(f"Failed to terminate API server: {e}")


@pytest.fixture(scope="session")
def api_url(request):
    """
    Returns the API endpoint URL for chat completions.
    """
    return f"http://0.0.0.0:{FD_ROUTER_PORT}/v1/chat/completions"


@pytest.fixture(scope="session")
def metrics_url(request):
    """
    Returns the metrics endpoint URL.
    """
    return f"http://0.0.0.0:{FD_METRICS_PORT}/metrics"


@pytest.fixture
def headers():
    """
    Returns common HTTP request headers.
    """
    return {"Content-Type": "application/json"}


def test_metrics_config(metrics_url):
    timeout = 600
    url = metrics_url.replace("metrics", "config-info")
    res = requests.get(url, timeout=timeout)
    assert res.status_code == 200


def send_request(url, payload, timeout=600):
    """
    Send a request to the given URL and return the response.
    """
    headers = {
        "Content-Type": "application/json",
    }

    try:
        res = requests.post(url, headers=headers, json=payload, timeout=timeout)
        print("🟢 Receiving response...\n")
        return res
    except requests.exceptions.Timeout:
        print(f"❌ Request timed out (exceeded {timeout} seconds)")
        return None
    except requests.exceptions.RequestException as e:
        print(f"❌ Request failed: {e}")
        return None


def get_stream_chunks(response):
    """Parse a streaming response into a list of chunk dicts."""
    chunks = []

    if response.status_code == 200:
        for line in response.iter_lines(decode_unicode=True):
            if line:
                if line.startswith("data: "):
                    line = line[len("data: ") :]

                if line.strip() == "[DONE]":
                    break

                try:
                    chunk = json.loads(line)
                    chunks.append(chunk)
                except Exception as e:
                    print(f"Failed to parse: {e}, line content: {line}")
    else:
        print(f"Request failed, status code: {response.status_code}")
        print("Response body:", response.text)

    return chunks
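To make the expected SSE framing concrete, here is a small self-contained check against a stand-in for the streaming response object; it is not part of the new test file:

    class _FakeResponse:
        status_code = 200

        def iter_lines(self, decode_unicode=True):
            yield 'data: {"choices": [{"delta": {"content": "Hi"}}]}'
            yield ""  # keep-alive blank line, skipped by the `if line` check
            yield "data: [DONE]"

    chunks = get_stream_chunks(_FakeResponse())
    assert chunks[0]["choices"][0]["delta"]["content"] == "Hi"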


def test_chat_usage_stream(api_url):
    """Test streaming chat usage."""
    payload = {
        "model": "default",
        "temperature": 0,
        "top_p": 0,
        "seed": 33,
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "What are Newton's three laws of motion?"},
        ],
        "max_tokens": 50,
        "stream": True,
        "stream_options": {"include_usage": True, "continuous_usage_stats": True},
        "metadata": {"min_tokens": 10},
    }

    response = send_request(url=api_url, payload=payload)
    chunks = get_stream_chunks(response)
    result = "".join([x["choices"][0]["delta"]["content"] for x in chunks[:-1]])
    print("Decode Response:", result)
    assert result != "", "result is empty"
    usage = chunks[-1]["usage"]
    total_tokens = usage["completion_tokens"] + usage["prompt_tokens"]
    assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens exceeds max_tokens"
    assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens is less than min_tokens"
    assert usage["total_tokens"] == total_tokens, "total_tokens does not equal prompt_tokens + completion_tokens"


def test_chat_usage_non_stream(api_url):
    """Test non-streaming chat usage."""
    payload = {
        "model": "default",
        "temperature": 0,
        "top_p": 0,
        "seed": 33,
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "What are Newton's three laws of motion?"},
        ],
        "max_tokens": 50,
        "stream": False,
        "metadata": {"min_tokens": 10},
    }

    response = send_request(url=api_url, payload=payload).json()
    usage = response["usage"]
    result = response["choices"][0]["message"]["content"]
    assert result != "", "result is empty"
    total_tokens = usage["completion_tokens"] + usage["prompt_tokens"]
    assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens exceeds max_tokens"
    assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens is less than min_tokens"
    assert usage["total_tokens"] == total_tokens, "total_tokens does not equal prompt_tokens + completion_tokens"


def test_non_chat_usage_stream(api_url):
    """Test streaming non-chat usage."""
    payload = {
        "model": "default",
        "temperature": 0,
        "top_p": 0,
        "seed": 33,
        "prompt": "What are Newton's three laws of motion?",
        "max_tokens": 50,
        "stream": True,
        "stream_options": {"include_usage": True, "continuous_usage_stats": True},
        "metadata": {"min_tokens": 10},
    }
    api_url = api_url.replace("chat/completions", "completions")

    response = send_request(url=api_url, payload=payload)
    chunks = get_stream_chunks(response)
    result = "".join([x["choices"][0]["text"] for x in chunks[:-1]])
    print("Decode Response:", result)
    assert result != "", "result is empty"
    usage = chunks[-1]["usage"]
    total_tokens = usage["completion_tokens"] + usage["prompt_tokens"]
    assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens exceeds max_tokens"
    assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens is less than min_tokens"
    assert usage["total_tokens"] == total_tokens, "total_tokens does not equal prompt_tokens + completion_tokens"


def test_non_chat_usage_non_stream(api_url):
    """Test non-streaming non-chat usage."""
    payload = {
        "model": "default",
        "temperature": 0,
        "top_p": 0,
        "seed": 33,
        "prompt": "What are Newton's three laws of motion?",
        "max_tokens": 50,
        "stream": False,
        "metadata": {"min_tokens": 10},
    }
    api_url = api_url.replace("chat/completions", "completions")

    response = send_request(url=api_url, payload=payload).json()
    usage = response["usage"]
    result = response["choices"][0]["text"]
    print("Decode Response:", result)
    assert result != "", "result is empty"
    total_tokens = usage["completion_tokens"] + usage["prompt_tokens"]
    assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens exceeds max_tokens"
    assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens is less than min_tokens"
    assert usage["total_tokens"] == total_tokens, "total_tokens does not equal prompt_tokens + completion_tokens"