[Optimize] Support and robust for tpN for PD (#4595)

* [Optimize] Support and robust for tpN for PD

* fix

* fix

* support dpM tpN for cache messager

* fix

* fix token counter

* fix bug for merge develop

* fix bug

* robust cache messager for v0
This commit is contained in:
chenjian
2025-11-03 15:38:31 +08:00
committed by GitHub
parent 7b35488779
commit 25498efcf3
9 changed files with 452 additions and 197 deletions

View File

@@ -246,11 +246,10 @@ class CacheMessager:
engine_recycled_count = 0
while True:
cache_info = self.engine_worker_queue.get_cache_info()
if cache_info:
logger.debug(f"cache info {cache_info}")
self.engine_worker_queue.cache_info_barrier.wait()
for info in cache_info:
if info["request_id"] in self.cache_info:
self.cache_info[info["request_id"]].update(info)
@@ -295,9 +294,6 @@ class CacheMessager:
continue
if "layer_idx" not in item:
item["layer_idx"] = 0
if item["status"] == "error":
del self.cache_info[req_id]
continue
if item["current_id"] > prefilled_step_idx:
continue
current_transfer_protocol = item["transfer_protocol"]
@@ -307,11 +303,7 @@ class CacheMessager:
status = self.messager[current_transfer_protocol].connect(target_ip, target_id)
if not status:
logger.error(f"connect to {target_ip}:{target_id} failed")
item["status"] = "error"
self.engine_worker_queue.finish_request_barrier.wait()
if self.rank == 0:
self.engine_worker_queue.put_finished_req([(item["request_id"], "connect error")])
continue
item["status"] = "connect error"
elif item["transfer_protocol"] == "ipc":
target_ip = "0.0.0.0"
target_id = int(item["device_ids"][self.rank])
@@ -321,48 +313,43 @@ class CacheMessager:
current_layer_idx = self.num_layers
else:
current_layer_idx = prefilled_layer_idx + 1
for layer_idx in range(item["layer_idx"], current_layer_idx):
tic = time.time()
return_code = self.messager[current_transfer_protocol].write_cache(
target_ip,
target_id,
src_block_ids,
dest_block_ids,
layer_idx,
)
if return_code != 0:
item["status"] = "error"
self.engine_worker_queue.finish_request_barrier.wait()
if self.rank == 0:
self.engine_worker_queue.put_finished_req([(item["request_id"], "write cache error")])
logger.error(
f"write cache failed, layer_idx: {layer_idx}, "
f"req_id: {item['request_id']}, dest_ip: {target_ip}"
if "error" not in item["status"]:
for layer_idx in range(item["layer_idx"], current_layer_idx):
tic = time.time()
return_code = self.messager[current_transfer_protocol].write_cache(
target_ip,
target_id,
src_block_ids,
dest_block_ids,
layer_idx,
)
break
if return_code != 0:
item["status"] = "write cache error"
logger.error(
f"write cache failed, layer_idx: {layer_idx}, "
f"req_id: {item['request_id']}, dest_ip: {target_ip}"
)
break
tok = time.time()
cost_time = tok - tic
block_num = len(src_block_ids)
avg_time_per_block = cost_time * 1000 / block_num # ms
send_cache_speed = block_num * self.block_bytes / 1073741824 / cost_time # GB/s
logger.debug(
f"finish write cache for a layer, {item['request_id']}, {layer_idx}"
f" {current_transfer_protocol}"
f"block_num: {block_num}, send_cache_speed(GB/s): {round(send_cache_speed, 5)},"
f"avg_time per block(ms): {round(avg_time_per_block, 5)}"
)
tok = time.time()
cost_time = tok - tic
block_num = len(src_block_ids)
avg_time_per_block = cost_time * 1000 / block_num # ms
send_cache_speed = block_num * self.block_bytes / 1073741824 / cost_time # GB/s
logger.debug(
f"finish write cache for a layer, {item['request_id']}, {layer_idx}"
f" {current_transfer_protocol}"
f"block_num: {block_num}, send_cache_speed(GB/s): {round(send_cache_speed, 5)},"
f"avg_time per block(ms): {round(avg_time_per_block, 5)}"
)
item["layer_idx"] = current_layer_idx
if item["layer_idx"] == self.num_layers:
if item["transfer_protocol"] == "ipc":
self.messager["ipc"].write_block_by_sync(target_id)
logger.info(f"finish write cache {item['request_id']}")
self.engine_worker_queue.finish_request_barrier.wait()
if self.rank == 0:
# to do: robust in TP: here we assume all status in tp are the same. If wrong, all wrong. If ok, all ok.
self.engine_worker_queue.put_finished_req([(item["request_id"], "finished")])
logger.info(f"put write cache {item['request_id']}")
self.engine_worker_queue.finish_send_cache_barrier.wait()
self.engine_worker_queue.put_finished_req([[item["request_id"], item["status"]]])
logger.info(f"put write cache {item['request_id']}, status {item['status']}")
del self.cache_info[req_id]
self.last_layer_idx = prefilled_layer_idx
@@ -376,14 +363,17 @@ class CacheMessager:
if task is None:
time.sleep(0.001)
continue
else:
self.engine_worker_queue.connect_task_barrier.wait()
logger.info(f"_handle_connect_task recv task: {task}")
task_id = task["task_id"]
ip, rdma_port = task["ip"], task["rdma_port"]
ip, rdma_port = task["ip"], task["rdma_ports"][self.rank]
status = self.messager["rdma"].connect(ip, rdma_port)
if not status:
response = {"task_id": task_id, "success": False}
else:
response = {"task_id": task_id, "success": True}
self.engine_worker_queue.connect_task_response_barrier.wait()
self.engine_worker_queue.put_connect_rdma_task_response(response)
except Exception as e:
logger.error(f"handle_connect_task has exception: {e}")
@@ -524,9 +514,9 @@ class CacheMessagerV1:
while True:
try:
cache_info = self.engine_worker_queue.get_cache_info()
self.engine_worker_queue.finish_add_cache_task_barrier.wait()
finished_add_cache_task_req_ids = []
if cache_info:
self.engine_worker_queue.cache_info_barrier.wait()
for info in cache_info:
if info["request_id"] in self.cache_info:
self.cache_info[info["request_id"]].update(info)
@@ -544,13 +534,16 @@ class CacheMessagerV1:
current_info["sended_layer_id"] = -1
current_info["sended_block_num"] = current_info["decode_cached_tokens"] // self.block_size
current_info["status"] = "init"
logger.info(f"finish add cache task: {current_info}")
logger.info(f"Get cache info from P: finish add cache task: {current_info}")
self.cache_info[info["request_id"]] = current_info
self.idx_cache_task_dict[current_info["current_id"]] = current_info
else:
logger.info(f"Get cache info from D: {info}")
self.cache_info[info["request_id"]] = info
if self.rank == 0 and finished_add_cache_task_req_ids:
if finished_add_cache_task_req_ids:
self.engine_worker_queue.put_finished_add_cache_task_req(finished_add_cache_task_req_ids)
self.engine_worker_queue.finish_add_cache_task_barrier.wait()
else:
time.sleep(0.001)
except Exception as e:
@@ -563,14 +556,16 @@ class CacheMessagerV1:
"""
while True:
try:
engine_indexes = self.cache_prefilled_engine_ids_queue.get()
self.engine_worker_queue.finish_request_barrier.wait()
batch_engine_signals = self.cache_prefilled_engine_ids_queue.get()
self.engine_worker_queue.begin_send_cache_barrier.wait()
block_start_end_list = []
current_prefilled_token_num_list = []
for engine_index in engine_indexes:
assert engine_index in self.idx_cache_task_dict
for engine_index, current_step_prefilled_token_num in batch_engine_signals:
assert (
engine_index in self.idx_cache_task_dict
), f"engine_index {engine_index} not in self.idx_cache_task_dict {self.idx_cache_task_dict}"
block_id_start = self.idx_cache_task_dict[engine_index]["sended_block_num"]
prefilled_token_num = self.engine_cache_tasks[engine_index]["prefilled_token_num"]
prefilled_token_num = current_step_prefilled_token_num
if (
prefilled_token_num == self.idx_cache_task_dict[engine_index]["need_prefill_tokens"]
): # all chunks have been prefilled
@@ -580,17 +575,20 @@ class CacheMessagerV1:
block_start_end_list.append((block_id_start, block_id_end))
current_prefilled_token_num_list.append(prefilled_token_num)
while True: # from layer0 to last layer
sended_layer_idx = self.idx_cache_task_dict[engine_indexes[0]]["sended_layer_id"]
sended_layer_idx = self.idx_cache_task_dict[batch_engine_signals[0][0]]["sended_layer_id"]
start_layer_idx = sended_layer_idx + 1
with self.engine_cache_task_thread_lock: # to check end_layer_idx
prefilled_layer_idx = self.engine_cache_tasks[engine_indexes[0]]["prefilled_layer_idx"]
prefilled_layer_idx = self.engine_cache_tasks[batch_engine_signals[0][0]][
"prefilled_layer_idx"
]
if sended_layer_idx > prefilled_layer_idx: # computation must in next chunk
logger.info(
f"current_prefilled_token_num_list[0] {current_prefilled_token_num_list[0]} prefilled_token_num {self.engine_cache_tasks[engine_indexes[0]]['prefilled_token_num']}"
f"current_prefilled_token_num_list[0] {current_prefilled_token_num_list[0]} prefilled_token_num {self.engine_cache_tasks[batch_engine_signals[0][0]]['prefilled_token_num']}"
)
assert (
current_prefilled_token_num_list[0]
< self.engine_cache_tasks[engine_indexes[0]]["prefilled_token_num"]
< self.engine_cache_tasks[batch_engine_signals[0][0]]["prefilled_token_num"]
), "when sended_layer_idx > prefilled_layer_idx, must be in next chunk, but not, sth wrong"
end_layer_idx = self.num_layers - 1 # [start_layer_idx, end_layer_idx)
else:
@@ -599,7 +597,7 @@ class CacheMessagerV1:
time.sleep(0.01)
for layer_idx in range(start_layer_idx, end_layer_idx + 1):
for i, (block_id_start, block_id_end) in enumerate(block_start_end_list):
engine_index = engine_indexes[i]
engine_index = batch_engine_signals[i][0]
task = self.idx_cache_task_dict[engine_index]
req_id = task["request_id"]
if (
@@ -615,7 +613,7 @@ class CacheMessagerV1:
if task["transfer_protocol"] == "rdma":
target_ip = task["ip"]
target_id = int(task["rdma_ports"][self.rank])
if task["status"] == "error":
if "error" in task["status"]:
continue
status = self.messager[current_transfer_protocol].connect(target_ip, target_id)
if not status:
@@ -665,7 +663,7 @@ class CacheMessagerV1:
block_id_end - block_id_start
)
if current_prefilled_token_num_list[i] == task["need_prefill_tokens"]:
if task["status"] != "error":
if "error" not in task["status"]:
task["status"] = "finished"
logger.info(
f"finish write cache for all layers, req_id: {req_id}, block_id_end {block_id_end} need_prefill_tokens {task['need_prefill_tokens']}"
@@ -674,18 +672,15 @@ class CacheMessagerV1:
task["sended_layer_id"] = -1
if end_layer_idx == self.num_layers - 1:
with self.engine_cache_task_thread_lock:
for engine_idx in engine_indexes:
for engine_idx, _ in batch_engine_signals:
task = self.idx_cache_task_dict[engine_idx]
if task["status"] == "finished" or ("error" in task["status"]):
target_id = int(task["rdma_ports"][self.rank])
if task["transfer_protocol"] == "ipc":
self.messager["ipc"].write_block_by_sync(target_id)
if self.rank == 0:
# to do: robust in TP, here we assume all status in tp are the same. If wrong, all wrong. If ok, all ok.
self.engine_worker_queue.put_finished_req(
[(task["request_id"], task["status"])]
)
logger.info(f"put write cache {task['request_id']}, status {task['status']}")
self.engine_worker_queue.finish_send_cache_barrier.wait()
self.engine_worker_queue.put_finished_req([[task["request_id"], task["status"]]])
logger.info(f"put write cache {task['request_id']}, status {task['status']}")
self.engine_cache_tasks[task["current_id"]] = dict()
del self.cache_info[task["request_id"]]
del self.idx_cache_task_dict[task["current_id"]]
@@ -709,8 +704,9 @@ class CacheMessagerV1:
continue
layer_id = kv_signal_data[1].numpy().tolist()
if layer_id == self.num_layers - 1:
logger.info(f"tasks_count: {tasks_count}, layer_id: {layer_id}")
batch_engine_ids = []
logger.info(f"tasks_count: {tasks_count}, layer_id: {layer_id} self.rank_id {self.rank_id}")
batch_engine_signals = []
# format for signal to put in cache_prefilled_engine_ids_queue: [(engine_idx1, prefilled_token_num1), (engine_idx2, prefilled_token_num2)]
with self.engine_cache_task_thread_lock:
for bi in range(tasks_count):
engine_idx = kv_signal_data[3 * bi + 2].numpy().tolist()
@@ -720,27 +716,33 @@ class CacheMessagerV1:
self.engine_cache_tasks[engine_idx]["prefilled_token_num"] = (
chuck_token_offset + current_seq_len
)
batch_engine_ids.append(engine_idx)
batch_engine_signals.append((engine_idx, chuck_token_offset + current_seq_len))
if layer_id == 0:
self.cache_prefilled_engine_ids_queue.put(batch_engine_ids)
logger.info(
f"Put batch_engine_signals {batch_engine_signals} into cache_prefilled_engine_ids_queue"
)
self.cache_prefilled_engine_ids_queue.put(batch_engine_signals)
except Exception as e:
logger.error(f"Consume signals get exception: {e}")
def _handle_connect_task(self):
while True:
try:
task = self.engine_worker_queue.get_connect_rdma_task()
task, _ = self.engine_worker_queue.get_connect_rdma_task()
if task is None:
time.sleep(0.001)
continue
else:
self.engine_worker_queue.connect_task_barrier.wait()
logger.info(f"_handle_connect_task recv task: {task}")
task_id = task["task_id"]
ip, rdma_port = task["ip"], task["rdma_port"]
ip, rdma_port = task["ip"], task["rdma_ports"][self.rank]
status = self.messager["rdma"].connect(ip, rdma_port)
if not status:
response = {"task_id": task_id, "success": False}
else:
response = {"task_id": task_id, "success": True}
self.engine_worker_queue.connect_task_response_barrier.wait()
self.engine_worker_queue.put_connect_rdma_task_response(response)
except Exception as e:
logger.error(f"handle_connect_task has exception: {e}")