mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[PD Disaggregation] support different tp_size for prefill and decode (#5296)
* up * up * up * fix
This commit is contained in:
@@ -55,13 +55,13 @@ def parse_args():
|
||||
default="mixed",
|
||||
help="splitwise role, can be decode, prefill or mixed",
|
||||
)
|
||||
parser.add_argument("--rank", type=int, default=0, help="current rank")
|
||||
parser.add_argument("--rank", type=int, default=0, help="local tp rank id")
|
||||
parser.add_argument("--device_id", type=int, default=0, help="device id")
|
||||
parser.add_argument("--num_layers", type=int, default=1, help="model num layers")
|
||||
parser.add_argument("--key_cache_shape", type=str, default="", help="key cache shape")
|
||||
parser.add_argument("--value_cache_shape", type=str, default="", help="value cache shape")
|
||||
parser.add_argument("--rdma_port", type=str, default="", help="rmda port")
|
||||
parser.add_argument("--mp_num", type=int, default=1, help="number of model parallel")
|
||||
parser.add_argument("--mp_num", type=int, default=1, help="number of model parallel, i.e. tp_size, tp_num")
|
||||
parser.add_argument("--engine_pid", type=str, default=None, help="engine pid")
|
||||
parser.add_argument(
|
||||
"--protocol",
|
||||
@@ -208,6 +208,8 @@ class CacheMessager:
|
||||
max_block_num,
|
||||
block_bytes,
|
||||
rdma_port,
|
||||
nranks,
|
||||
rank,
|
||||
)
|
||||
|
||||
self.gpu_id = gpu_id
|
||||
@@ -507,6 +509,8 @@ class CacheMessagerV1:
|
||||
max_block_num,
|
||||
block_bytes,
|
||||
rdma_port,
|
||||
nranks,
|
||||
rank,
|
||||
)
|
||||
|
||||
self.gpu_id = gpu_id
|
||||
@@ -595,6 +599,7 @@ class CacheMessagerV1:
|
||||
block_id_end = prefilled_token_num // self.block_size # [block_id_start, block_id_end)
|
||||
block_start_end_list.append((block_id_start, block_id_end))
|
||||
current_prefilled_token_num_list.append(prefilled_token_num)
|
||||
|
||||
while True: # from layer0 to last layer
|
||||
sended_layer_idx = self.idx_cache_task_dict[batch_engine_signals[0][0]]["sended_layer_id"]
|
||||
start_layer_idx = sended_layer_idx + 1
|
||||
@@ -633,13 +638,27 @@ class CacheMessagerV1:
|
||||
current_transfer_protocol = task["transfer_protocol"]
|
||||
if task["transfer_protocol"] == "rdma":
|
||||
target_ip = task["ip"]
|
||||
target_id = int(task["rdma_ports"][self.rank])
|
||||
# Default decode_tp_size to prefill tp_size (self.nranks) if not specified
|
||||
decode_tp_size = task.get("decode_tp_size", self.nranks)
|
||||
if len(task["rdma_ports"]) == self.nranks:
|
||||
target_id = int(task["rdma_ports"][self.rank])
|
||||
elif len(task["rdma_ports"]) == 1:
|
||||
target_id = task["rdma_ports"][0]
|
||||
else:
|
||||
task["status"] = "the tp_size of prefill and decode is mismatch"
|
||||
continue
|
||||
|
||||
if "error" in task["status"]:
|
||||
continue
|
||||
|
||||
# TODO: use is connected to check if the connection is still alive
|
||||
logger.debug(f"rdma, start connect decode, {target_ip}:{target_id}")
|
||||
status = self.messager[current_transfer_protocol].connect(target_ip, target_id)
|
||||
logger.debug(
|
||||
f"rdma, start connect decode, {target_ip}:{target_id}, "
|
||||
f"prefill_tp_size:{self.nranks}, decode_tp_size:{decode_tp_size}"
|
||||
)
|
||||
status = self.messager[current_transfer_protocol].connect(
|
||||
target_ip, target_id, decode_tp_size
|
||||
)
|
||||
if status:
|
||||
logger.info(f"connect to {target_ip}:{target_id} success")
|
||||
else:
|
||||
@@ -762,12 +781,22 @@ class CacheMessagerV1:
|
||||
self.engine_worker_queue.connect_task_barrier.wait()
|
||||
logger.info(f"_handle_connect_task recv task: {task}")
|
||||
task_id = task["task_id"]
|
||||
ip, rdma_port = task["ip"], task["rdma_ports"][self.rank]
|
||||
status = self.messager["rdma"].connect(ip, rdma_port)
|
||||
if not status:
|
||||
ip = task["ip"]
|
||||
# Default decode_tp_size to self.nranks (number of ranks) if not specified in the task.
|
||||
decode_tp_size = task.get("decode_tp_size", self.nranks)
|
||||
rdma_ports = task["rdma_ports"]
|
||||
rdma_ports_len = len(rdma_ports)
|
||||
if not (rdma_ports_len == 1 or rdma_ports_len == self.nranks):
|
||||
# TODO: support other cases
|
||||
logger.error(f"rdma_ports length should be 1 or equal to mp_num, but got {rdma_ports_len}")
|
||||
response = {"task_id": task_id, "success": False}
|
||||
else:
|
||||
response = {"task_id": task_id, "success": True}
|
||||
port = rdma_ports[0] if rdma_ports_len == 1 else rdma_ports[self.rank]
|
||||
status = self.messager["rdma"].connect(ip, port, decode_tp_size)
|
||||
if not status:
|
||||
response = {"task_id": task_id, "success": False}
|
||||
else:
|
||||
response = {"task_id": task_id, "success": True}
|
||||
self.engine_worker_queue.connect_task_response_barrier.wait()
|
||||
self.engine_worker_queue.put_connect_rdma_task_response(response)
|
||||
except Exception as e:
|
||||
|
||||
Reference in New Issue
Block a user