mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[PD Disaggregation] support different tp_size for prefill and decode (#5296)
* up * up * up * fix
This commit is contained in:
@@ -55,13 +55,13 @@ def parse_args():
|
||||
default="mixed",
|
||||
help="splitwise role, can be decode, prefill or mixed",
|
||||
)
|
||||
parser.add_argument("--rank", type=int, default=0, help="current rank")
|
||||
parser.add_argument("--rank", type=int, default=0, help="local tp rank id")
|
||||
parser.add_argument("--device_id", type=int, default=0, help="device id")
|
||||
parser.add_argument("--num_layers", type=int, default=1, help="model num layers")
|
||||
parser.add_argument("--key_cache_shape", type=str, default="", help="key cache shape")
|
||||
parser.add_argument("--value_cache_shape", type=str, default="", help="value cache shape")
|
||||
parser.add_argument("--rdma_port", type=str, default="", help="rmda port")
|
||||
parser.add_argument("--mp_num", type=int, default=1, help="number of model parallel")
|
||||
parser.add_argument("--mp_num", type=int, default=1, help="number of model parallel, i.e. tp_size, tp_num")
|
||||
parser.add_argument("--engine_pid", type=str, default=None, help="engine pid")
|
||||
parser.add_argument(
|
||||
"--protocol",
|
||||
@@ -208,6 +208,8 @@ class CacheMessager:
|
||||
max_block_num,
|
||||
block_bytes,
|
||||
rdma_port,
|
||||
nranks,
|
||||
rank,
|
||||
)
|
||||
|
||||
self.gpu_id = gpu_id
|
||||
@@ -507,6 +509,8 @@ class CacheMessagerV1:
|
||||
max_block_num,
|
||||
block_bytes,
|
||||
rdma_port,
|
||||
nranks,
|
||||
rank,
|
||||
)
|
||||
|
||||
self.gpu_id = gpu_id
|
||||
@@ -595,6 +599,7 @@ class CacheMessagerV1:
|
||||
block_id_end = prefilled_token_num // self.block_size # [block_id_start, block_id_end)
|
||||
block_start_end_list.append((block_id_start, block_id_end))
|
||||
current_prefilled_token_num_list.append(prefilled_token_num)
|
||||
|
||||
while True: # from layer0 to last layer
|
||||
sended_layer_idx = self.idx_cache_task_dict[batch_engine_signals[0][0]]["sended_layer_id"]
|
||||
start_layer_idx = sended_layer_idx + 1
|
||||
@@ -633,13 +638,27 @@ class CacheMessagerV1:
|
||||
current_transfer_protocol = task["transfer_protocol"]
|
||||
if task["transfer_protocol"] == "rdma":
|
||||
target_ip = task["ip"]
|
||||
target_id = int(task["rdma_ports"][self.rank])
|
||||
# Default decode_tp_size to prefill tp_size (self.nranks) if not specified
|
||||
decode_tp_size = task.get("decode_tp_size", self.nranks)
|
||||
if len(task["rdma_ports"]) == self.nranks:
|
||||
target_id = int(task["rdma_ports"][self.rank])
|
||||
elif len(task["rdma_ports"]) == 1:
|
||||
target_id = task["rdma_ports"][0]
|
||||
else:
|
||||
task["status"] = "the tp_size of prefill and decode is mismatch"
|
||||
continue
|
||||
|
||||
if "error" in task["status"]:
|
||||
continue
|
||||
|
||||
# TODO: use is connected to check if the connection is still alive
|
||||
logger.debug(f"rdma, start connect decode, {target_ip}:{target_id}")
|
||||
status = self.messager[current_transfer_protocol].connect(target_ip, target_id)
|
||||
logger.debug(
|
||||
f"rdma, start connect decode, {target_ip}:{target_id}, "
|
||||
f"prefill_tp_size:{self.nranks}, decode_tp_size:{decode_tp_size}"
|
||||
)
|
||||
status = self.messager[current_transfer_protocol].connect(
|
||||
target_ip, target_id, decode_tp_size
|
||||
)
|
||||
if status:
|
||||
logger.info(f"connect to {target_ip}:{target_id} success")
|
||||
else:
|
||||
@@ -762,12 +781,22 @@ class CacheMessagerV1:
|
||||
self.engine_worker_queue.connect_task_barrier.wait()
|
||||
logger.info(f"_handle_connect_task recv task: {task}")
|
||||
task_id = task["task_id"]
|
||||
ip, rdma_port = task["ip"], task["rdma_ports"][self.rank]
|
||||
status = self.messager["rdma"].connect(ip, rdma_port)
|
||||
if not status:
|
||||
ip = task["ip"]
|
||||
# Default decode_tp_size to self.nranks (number of ranks) if not specified in the task.
|
||||
decode_tp_size = task.get("decode_tp_size", self.nranks)
|
||||
rdma_ports = task["rdma_ports"]
|
||||
rdma_ports_len = len(rdma_ports)
|
||||
if not (rdma_ports_len == 1 or rdma_ports_len == self.nranks):
|
||||
# TODO: support other cases
|
||||
logger.error(f"rdma_ports length should be 1 or equal to mp_num, but got {rdma_ports_len}")
|
||||
response = {"task_id": task_id, "success": False}
|
||||
else:
|
||||
response = {"task_id": task_id, "success": True}
|
||||
port = rdma_ports[0] if rdma_ports_len == 1 else rdma_ports[self.rank]
|
||||
status = self.messager["rdma"].connect(ip, port, decode_tp_size)
|
||||
if not status:
|
||||
response = {"task_id": task_id, "success": False}
|
||||
else:
|
||||
response = {"task_id": task_id, "success": True}
|
||||
self.engine_worker_queue.connect_task_response_barrier.wait()
|
||||
self.engine_worker_queue.put_connect_rdma_task_response(response)
|
||||
except Exception as e:
|
||||
|
||||
@@ -142,6 +142,7 @@ struct Connection {
|
||||
int wc_target_count;
|
||||
|
||||
// Configuration
|
||||
int decode_tp_size;
|
||||
int layer_number;
|
||||
int block_number;
|
||||
int block_byte_size;
|
||||
|
||||
@@ -24,11 +24,15 @@ class RDMACommunicator {
|
||||
std::vector<int64_t> local_key_cache,
|
||||
std::vector<int64_t> local_value_cache,
|
||||
int block_number,
|
||||
int block_bytes);
|
||||
int block_bytes,
|
||||
int prefill_tp_size,
|
||||
int prefill_tp_idx);
|
||||
~RDMACommunicator();
|
||||
|
||||
// Connection management
|
||||
int connect(const std::string& dst_ip, const std::string& dst_port);
|
||||
int connect(const std::string& dst_ip,
|
||||
const std::string& dst_port,
|
||||
int dest_tp_size);
|
||||
bool is_connected(const std::string& dst_ip, const std::string& dst_port);
|
||||
|
||||
// Core functionality
|
||||
@@ -120,6 +124,8 @@ class RDMACommunicator {
|
||||
int block_number; // Number of blocks
|
||||
int block_size_byte; // Size of each block in bytes
|
||||
int layer_number; // Number of layers
|
||||
int prefill_tp_size; // tensor parallelism size for prefill
|
||||
int prefill_tp_idx; // tensor parallelism index for prefill
|
||||
|
||||
std::vector<std::vector<void*>>
|
||||
local_cache_key_ptr_per_layer; // Per-layer key pointers
|
||||
|
||||
@@ -41,7 +41,7 @@
|
||||
* @param local_key_cache Vector of local key cache pointers
|
||||
* @param local_value_cache Vector of local value cache pointers
|
||||
* @param block_number Number of blocks in cache
|
||||
* @param block_bytes Size of each block in bytes
|
||||
* @param block_bytes Bytes of each block in each tp rank
|
||||
*
|
||||
* @throws std::runtime_error If initialization fails
|
||||
*/
|
||||
@@ -51,7 +51,9 @@ RDMACommunicator::RDMACommunicator(std::string& role,
|
||||
std::vector<int64_t> local_key_cache,
|
||||
std::vector<int64_t> local_value_cache,
|
||||
int block_number,
|
||||
int block_bytes)
|
||||
int block_bytes,
|
||||
int prefill_tp_size,
|
||||
int prefill_tp_idx)
|
||||
: splitwise_role(role),
|
||||
gpu_idx(gpu_idx),
|
||||
port(port),
|
||||
@@ -59,6 +61,8 @@ RDMACommunicator::RDMACommunicator(std::string& role,
|
||||
local_cache_value_ptr_layer_head_(std::move(local_value_cache)),
|
||||
block_number(block_number),
|
||||
block_size_byte(block_bytes),
|
||||
prefill_tp_size(prefill_tp_size),
|
||||
prefill_tp_idx(prefill_tp_idx),
|
||||
RDMACommunicator_status(0),
|
||||
rdma_event_channel_epoll_fd(-1) {
|
||||
try {
|
||||
@@ -480,11 +484,14 @@ std::string RDMACommunicator::fetch_local_ip() {
|
||||
*
|
||||
* @param dst_ip Destination IP address
|
||||
* @param dst_port Destination port
|
||||
* @param dest_tp_size Default 0: assumes dest has same tp_size as source;
|
||||
* otherwise specifies decode tp_size
|
||||
* @return ConnStatus::kConnected ConnStatus::kError;
|
||||
*/
|
||||
|
||||
int RDMACommunicator::connect(const std::string& dst_ip,
|
||||
const std::string& dst_port) {
|
||||
const std::string& dst_port,
|
||||
int dest_tp_size = 0) {
|
||||
std::string url = dst_ip + ":" + dst_port;
|
||||
|
||||
// Initialize IB devices if not already done
|
||||
@@ -515,6 +522,10 @@ int RDMACommunicator::connect(const std::string& dst_ip,
|
||||
ctx->conn.layer_number = layer_number;
|
||||
ctx->conn.block_number = block_number;
|
||||
ctx->conn.block_byte_size = block_size_byte;
|
||||
if (dest_tp_size > 0)
|
||||
ctx->conn.decode_tp_size = dest_tp_size;
|
||||
else
|
||||
ctx->conn.decode_tp_size = prefill_tp_size;
|
||||
|
||||
// Get port information for the connection
|
||||
if (get_port_info(ctx->context, ib_dev->port, &ctx->portinfo)) {
|
||||
@@ -537,9 +548,6 @@ int RDMACommunicator::connect(const std::string& dst_ip,
|
||||
ERR("Couldn't getexchange port infodestinations");
|
||||
return static_cast<int>(ConnStatus::kError);
|
||||
} else {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
ctx->conn.connected = 1;
|
||||
conn_map[url] = ctx;
|
||||
client_exchange_mr(ctx);
|
||||
}
|
||||
|
||||
@@ -589,6 +597,10 @@ int RDMACommunicator::connect(const std::string& dst_ip,
|
||||
}
|
||||
}
|
||||
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
ctx->conn.connected = 1;
|
||||
conn_map[url] = ctx;
|
||||
|
||||
WARN("connect end ....");
|
||||
return static_cast<int>(ConnStatus::kConnected);
|
||||
}
|
||||
@@ -649,6 +661,7 @@ int RDMACommunicator::client_listener() {
|
||||
|
||||
bool RDMACommunicator::is_connected(const std::string& dst_ip,
|
||||
const std::string& dst_port) {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
std::string url = dst_ip + ":" + dst_port;
|
||||
return conn_map.find(url) != conn_map.end();
|
||||
}
|
||||
@@ -889,17 +902,25 @@ int RDMACommunicator::write_cache(const std::string& ip,
|
||||
uint32_t cache_value_rkey =
|
||||
ctx->conn.write_cache_value_remote_rkey_list[layer_idx];
|
||||
uint32_t crc_cache_key_rkey, crc_cache_value_rkey;
|
||||
bool pd_tp_size_is_same = prefill_tp_size == ctx->conn.decode_tp_size;
|
||||
uint64_t offset_in_block =
|
||||
pd_tp_size_is_same ? 0 : block_size_byte * prefill_tp_idx;
|
||||
uint64_t total_block_size_byte =
|
||||
pd_tp_size_is_same ? block_size_byte : block_size_byte * prefill_tp_size;
|
||||
|
||||
for (size_t block_index = 0; block_index < block_num; ++block_index) {
|
||||
char* char_ptr = static_cast<char*>(
|
||||
ctx->conn.write_cache_key_remote_ptr_list[layer_idx]);
|
||||
cache_key_remote_addr[block_index] =
|
||||
(uint64_t(char_ptr + remote_block_ids[block_index] * block_size_byte));
|
||||
cache_key_remote_addr[block_index] = (uint64_t(
|
||||
char_ptr + remote_block_ids[block_index] * total_block_size_byte +
|
||||
offset_in_block));
|
||||
char_ptr = static_cast<char*>(
|
||||
ctx->conn.write_cache_value_remote_ptr_list[layer_idx]);
|
||||
cache_value_remote_addr[block_index] =
|
||||
(uint64_t(char_ptr + remote_block_ids[block_index] * block_size_byte));
|
||||
cache_value_remote_addr[block_index] = (uint64_t(
|
||||
char_ptr + remote_block_ids[block_index] * total_block_size_byte +
|
||||
offset_in_block));
|
||||
}
|
||||
|
||||
ctx->conn.wc_target_count = 0;
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
bool is_key = (i == 0);
|
||||
|
||||
@@ -14,10 +14,39 @@ PYBIND11_MODULE(rdma_comm, m) {
|
||||
std::vector<int64_t>,
|
||||
std::vector<int64_t>,
|
||||
int,
|
||||
int>())
|
||||
.def("connect", &RDMACommunicator::connect)
|
||||
.def("is_connected", &RDMACommunicator::is_connected)
|
||||
.def("write_cache", &RDMACommunicator::write_cache);
|
||||
int,
|
||||
int,
|
||||
int>(),
|
||||
py::arg("splitwise_role"),
|
||||
py::arg("gpu_idx"),
|
||||
py::arg("port"),
|
||||
py::arg("key_cache_ptrs"),
|
||||
py::arg("value_cache_ptrs"),
|
||||
py::arg("block_number"),
|
||||
py::arg("block_bytes"),
|
||||
py::arg("prefill_tp_size") = 1,
|
||||
py::arg("prefill_tp_idx") = 0)
|
||||
.def("connect",
|
||||
&RDMACommunicator::connect,
|
||||
py::arg("dst_ip"),
|
||||
py::arg("dst_port"),
|
||||
py::arg("dst_tp_size") =
|
||||
0, // Default 0: assumes dest has same tp_size as source;
|
||||
// otherwise specifies decode tp_size
|
||||
py::call_guard<py::gil_scoped_release>())
|
||||
.def("is_connected",
|
||||
&RDMACommunicator::is_connected,
|
||||
py::arg("dst_ip"),
|
||||
py::arg("dst_port"),
|
||||
py::call_guard<py::gil_scoped_release>())
|
||||
.def("write_cache",
|
||||
&RDMACommunicator::write_cache,
|
||||
py::arg("dst_ip"),
|
||||
py::arg("dst_port"),
|
||||
py::arg("local_block_ids"),
|
||||
py::arg("remote_block_ids"),
|
||||
py::arg("layer_idx"),
|
||||
py::call_guard<py::gil_scoped_release>());
|
||||
|
||||
#ifdef VERSION_INFO
|
||||
m.attr("__version__") = VERSION_INFO;
|
||||
|
||||
@@ -34,6 +34,8 @@ class RDMACommManager:
|
||||
max_block_num,
|
||||
block_bytes,
|
||||
rdma_port,
|
||||
prefill_tp_size,
|
||||
prefill_tp_idx,
|
||||
):
|
||||
try:
|
||||
import rdma_comm
|
||||
@@ -51,12 +53,16 @@ class RDMACommManager:
|
||||
cache_v_ptr_list,
|
||||
max_block_num,
|
||||
block_bytes,
|
||||
prefill_tp_size,
|
||||
prefill_tp_idx,
|
||||
)
|
||||
self.splitwise_role = splitwise_role
|
||||
self.connected_rdma = set()
|
||||
logger.info(f"init rdma messager {gpu_id} {rdma_port}")
|
||||
logger.info(
|
||||
f"init rdma messager {gpu_id} {rdma_port}, prefill_tp_size: {prefill_tp_size}, prefill_tp_idx: {prefill_tp_idx}"
|
||||
)
|
||||
|
||||
def connect(self, ip, port):
|
||||
def connect(self, ip, port, tp_size):
|
||||
"""
|
||||
Connect to remote gpu and write cache.
|
||||
"""
|
||||
@@ -65,7 +71,7 @@ class RDMACommManager:
|
||||
if ret:
|
||||
return True
|
||||
|
||||
ret = self.messager.connect(ip, str(port))
|
||||
ret = self.messager.connect(ip, str(port), tp_size)
|
||||
logger.info(f"connect to remote rdma address {ip}:{port} status is {ret}")
|
||||
return ret == 0
|
||||
|
||||
|
||||
Reference in New Issue
Block a user