mirror of https://github.com/PaddlePaddle/FastDeploy.git
[Optimize] Support and robust for tpN for PD (#4595)
* [Optimize] Support and robust for tpN for PD
* fix
* fix
* support dpM tpN for cache messager
* fix
* fix token counter
* fix bug for merge develop
* fix bug
* robust cache messager for v0
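What the change does: with tensor parallelism (tpN) under prefill/decode (PD) disaggregation, one cache messager runs per TP rank, so shared-queue operations are now fenced with barriers and a request's status is published only after every rank has reached the same point. A minimal, runnable sketch of that pattern, using threading.Barrier and queue.Queue as stand-ins for the EngineWorkerQueue primitives the diff introduces (cache_info_barrier, finish_send_cache_barrier, connect_task_barrier, and so on):

```python
# Sketch only: one messager thread per TP rank; Barrier/Queue are stand-ins for
# the real EngineWorkerQueue primitives used in this PR.
import queue
import threading

TP_SIZE = 4
finish_send_cache_barrier = threading.Barrier(TP_SIZE)
finished_reqs: "queue.Queue[list]" = queue.Queue()

def messager_rank(rank: int, req_id: str, status: str) -> None:
    # ... each rank transfers its own KV-cache shard here ...
    finish_send_cache_barrier.wait()     # no rank publishes before all have finished
    finished_reqs.put([req_id, status])  # every rank reports its own status

threads = [threading.Thread(target=messager_rank, args=(r, "req-0", "finished")) for r in range(TP_SIZE)]
for t in threads:
    t.start()
for t in threads:
    t.join()

# The consumer can now verify that all ranks agree on the request's status.
statuses = [finished_reqs.get()[1] for _ in range(TP_SIZE)]
assert all(s == "finished" for s in statuses)
```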
@@ -246,11 +246,10 @@ class CacheMessager:
         engine_recycled_count = 0

         while True:
-
             cache_info = self.engine_worker_queue.get_cache_info()

             if cache_info:
-                logger.debug(f"cache info {cache_info}")
+                self.engine_worker_queue.cache_info_barrier.wait()
                 for info in cache_info:
                     if info["request_id"] in self.cache_info:
                         self.cache_info[info["request_id"]].update(info)
@@ -295,9 +294,6 @@ class CacheMessager:
                    continue
                if "layer_idx" not in item:
                    item["layer_idx"] = 0
-                if item["status"] == "error":
-                    del self.cache_info[req_id]
-                    continue
                if item["current_id"] > prefilled_step_idx:
                    continue
                current_transfer_protocol = item["transfer_protocol"]
@@ -307,11 +303,7 @@ class CacheMessager:
                    status = self.messager[current_transfer_protocol].connect(target_ip, target_id)
                    if not status:
                        logger.error(f"connect to {target_ip}:{target_id} failed")
-                        item["status"] = "error"
-                        self.engine_worker_queue.finish_request_barrier.wait()
-                        if self.rank == 0:
-                            self.engine_worker_queue.put_finished_req([(item["request_id"], "connect error")])
-                        continue
+                        item["status"] = "connect error"
                elif item["transfer_protocol"] == "ipc":
                    target_ip = "0.0.0.0"
                    target_id = int(item["device_ids"][self.rank])
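The rewritten error path above tags the item ("connect error") instead of reporting and bailing out immediately; reporting moves to the common finish path, so every TP rank keeps identical control flow. A sketch of the idea, where connect and write_cache are hypothetical stand-ins for the messager calls:

```python
# Sketch of the deferred-error pattern, assuming item["status"] carries tags such
# as "connect error" or "write cache error".
def send_layers(item: dict, num_layers: int) -> None:
    if not connect(item):
        item["status"] = "connect error"  # tag it; do not return early
    if "error" not in item["status"]:     # transfer loop is skipped on any error
        for layer_idx in range(num_layers):
            if write_cache(item, layer_idx) != 0:
                item["status"] = "write cache error"
                break
    # every rank reaches the same barrier / reporting code after this point

def connect(item: dict) -> bool:
    return item.get("reachable", True)

def write_cache(item: dict, layer_idx: int) -> int:
    return 0  # 0 means success in this sketch

item = {"status": "init", "reachable": False}
send_layers(item, num_layers=2)
print(item["status"])  # connect error
```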
@@ -321,48 +313,43 @@ class CacheMessager:
                        current_layer_idx = self.num_layers
                    else:
                        current_layer_idx = prefilled_layer_idx + 1

-                    for layer_idx in range(item["layer_idx"], current_layer_idx):
-                        tic = time.time()
-                        return_code = self.messager[current_transfer_protocol].write_cache(
-                            target_ip,
-                            target_id,
-                            src_block_ids,
-                            dest_block_ids,
-                            layer_idx,
-                        )
-                        if return_code != 0:
-                            item["status"] = "error"
-                            self.engine_worker_queue.finish_request_barrier.wait()
-                            if self.rank == 0:
-                                self.engine_worker_queue.put_finished_req([(item["request_id"], "write cache error")])
-                            logger.error(
-                                f"write cache failed, layer_idx: {layer_idx}, "
-                                f"req_id: {item['request_id']}, dest_ip: {target_ip}"
-                            )
-                            break
-
-                        tok = time.time()
-                        cost_time = tok - tic
-                        block_num = len(src_block_ids)
-                        avg_time_per_block = cost_time * 1000 / block_num  # ms
-                        send_cache_speed = block_num * self.block_bytes / 1073741824 / cost_time  # GB/s
-                        logger.debug(
-                            f"finish write cache for a layer, {item['request_id']}, {layer_idx}"
-                            f" {current_transfer_protocol}"
-                            f"block_num: {block_num}, send_cache_speed(GB/s): {round(send_cache_speed, 5)},"
-                            f"avg_time per block(ms): {round(avg_time_per_block, 5)}"
-                        )
+                    if "error" not in item["status"]:
+                        for layer_idx in range(item["layer_idx"], current_layer_idx):
+                            tic = time.time()
+                            return_code = self.messager[current_transfer_protocol].write_cache(
+                                target_ip,
+                                target_id,
+                                src_block_ids,
+                                dest_block_ids,
+                                layer_idx,
+                            )
+                            if return_code != 0:
+                                item["status"] = "write cache error"
+                                logger.error(
+                                    f"write cache failed, layer_idx: {layer_idx}, "
+                                    f"req_id: {item['request_id']}, dest_ip: {target_ip}"
+                                )
+                                break
+
+                            tok = time.time()
+                            cost_time = tok - tic
+                            block_num = len(src_block_ids)
+                            avg_time_per_block = cost_time * 1000 / block_num  # ms
+                            send_cache_speed = block_num * self.block_bytes / 1073741824 / cost_time  # GB/s
+                            logger.debug(
+                                f"finish write cache for a layer, {item['request_id']}, {layer_idx}"
+                                f" {current_transfer_protocol}"
+                                f"block_num: {block_num}, send_cache_speed(GB/s): {round(send_cache_speed, 5)},"
+                                f"avg_time per block(ms): {round(avg_time_per_block, 5)}"
+                            )
                    item["layer_idx"] = current_layer_idx
                    if item["layer_idx"] == self.num_layers:
                        if item["transfer_protocol"] == "ipc":
                            self.messager["ipc"].write_block_by_sync(target_id)
-                        logger.info(f"finish write cache {item['request_id']}")
-                        self.engine_worker_queue.finish_request_barrier.wait()
-                        if self.rank == 0:
-                            # to do: robust in TP: here we assume all status in tp are the same. If wrong, all wrong. If ok, all ok.
-                            self.engine_worker_queue.put_finished_req([(item["request_id"], "finished")])
-                            logger.info(f"put write cache {item['request_id']}")
+                        self.engine_worker_queue.finish_send_cache_barrier.wait()
+                        self.engine_worker_queue.put_finished_req([[item["request_id"], item["status"]]])
+                        logger.info(f"put write cache {item['request_id']}, status {item['status']}")
                        del self.cache_info[req_id]
                self.last_layer_idx = prefilled_layer_idx

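For reference, a worked example of the per-layer metrics computed above (1073741824 bytes = 1 GiB); block_bytes and cost_time are illustrative values, not numbers from the PR:

```python
# Worked example of the transfer metrics in the loop above (assumed values).
block_num = 64
block_bytes = 2 * 1024 * 1024  # 2 MiB per KV block (assumed)
cost_time = 0.05               # seconds for one layer (assumed)

avg_time_per_block = cost_time * 1000 / block_num                     # ms
send_cache_speed = block_num * block_bytes / 1073741824 / cost_time   # GiB/s
print(round(avg_time_per_block, 5), round(send_cache_speed, 5))
# 0.78125 2.5
```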
@@ -376,14 +363,17 @@ class CacheMessager:
                if task is None:
                    time.sleep(0.001)
                    continue
+                else:
+                    self.engine_worker_queue.connect_task_barrier.wait()
                logger.info(f"_handle_connect_task recv task: {task}")
                task_id = task["task_id"]
-                ip, rdma_port = task["ip"], task["rdma_port"]
+                ip, rdma_port = task["ip"], task["rdma_ports"][self.rank]
                status = self.messager["rdma"].connect(ip, rdma_port)
                if not status:
                    response = {"task_id": task_id, "success": False}
                else:
                    response = {"task_id": task_id, "success": True}
+                self.engine_worker_queue.connect_task_response_barrier.wait()
                self.engine_worker_queue.put_connect_rdma_task_response(response)
            except Exception as e:
                logger.error(f"handle_connect_task has exception: {e}")
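The connect task now ships one RDMA port per TP rank (rdma_ports) instead of a single rdma_port, so each rank dials the peer worker that owns the matching KV shard. A small sketch with assumed payload values:

```python
# Sketch of per-rank RDMA port selection; the task payload values are assumed.
task = {"task_id": "t-1", "ip": "10.0.0.2", "rdma_ports": [7001, 7002, 7003, 7004]}

def port_for_rank(task: dict, rank: int) -> int:
    # Each TP rank dials the decode-side worker holding the matching KV shard.
    return int(task["rdma_ports"][rank])

print([port_for_rank(task, r) for r in range(4)])  # [7001, 7002, 7003, 7004]
```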
@@ -524,9 +514,9 @@ class CacheMessagerV1:
        while True:
            try:
                cache_info = self.engine_worker_queue.get_cache_info()
-                self.engine_worker_queue.finish_add_cache_task_barrier.wait()
                finished_add_cache_task_req_ids = []
                if cache_info:
+                    self.engine_worker_queue.cache_info_barrier.wait()
                    for info in cache_info:
                        if info["request_id"] in self.cache_info:
                            self.cache_info[info["request_id"]].update(info)
@@ -544,13 +534,16 @@ class CacheMessagerV1:
                            current_info["sended_layer_id"] = -1
                            current_info["sended_block_num"] = current_info["decode_cached_tokens"] // self.block_size
                            current_info["status"] = "init"
-                            logger.info(f"finish add cache task: {current_info}")
+                            logger.info(f"Get cache info from P: finish add cache task: {current_info}")
                            self.cache_info[info["request_id"]] = current_info
                            self.idx_cache_task_dict[current_info["current_id"]] = current_info
                        else:
+                            logger.info(f"Get cache info from D: {info}")
                            self.cache_info[info["request_id"]] = info
-                    if self.rank == 0 and finished_add_cache_task_req_ids:
+
+                    if finished_add_cache_task_req_ids:
                        self.engine_worker_queue.put_finished_add_cache_task_req(finished_add_cache_task_req_ids)
+                        self.engine_worker_queue.finish_add_cache_task_barrier.wait()
                else:
                    time.sleep(0.001)
            except Exception as e:
@@ -563,14 +556,16 @@ class CacheMessagerV1:
        """
        while True:
            try:
-                engine_indexes = self.cache_prefilled_engine_ids_queue.get()
-                self.engine_worker_queue.finish_request_barrier.wait()
+                batch_engine_signals = self.cache_prefilled_engine_ids_queue.get()
+                self.engine_worker_queue.begin_send_cache_barrier.wait()
                block_start_end_list = []
                current_prefilled_token_num_list = []
-                for engine_index in engine_indexes:
-                    assert engine_index in self.idx_cache_task_dict
+                for engine_index, current_step_prefilled_token_num in batch_engine_signals:
+                    assert (
+                        engine_index in self.idx_cache_task_dict
+                    ), f"engine_index {engine_index} not in self.idx_cache_task_dict {self.idx_cache_task_dict}"
                    block_id_start = self.idx_cache_task_dict[engine_index]["sended_block_num"]
-                    prefilled_token_num = self.engine_cache_tasks[engine_index]["prefilled_token_num"]
+                    prefilled_token_num = current_step_prefilled_token_num
                    if (
                        prefilled_token_num == self.idx_cache_task_dict[engine_index]["need_prefill_tokens"]
                    ):  # all chunks have been prefilled
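This is the token-counter fix: the queue now carries (engine_idx, prefilled_token_num) pairs snapshotted when the signal is produced, rather than bare engine indexes that forced the sender to re-read a counter the engine may already have advanced into the next chunk. A sketch of the new payload:

```python
# Sketch of the new queue payload; queue.Queue stands in for the real
# cache_prefilled_engine_ids_queue, and the pairs below are assumed values.
import queue

cache_prefilled_engine_ids_queue: "queue.Queue[list]" = queue.Queue()

# Producer side (consume_signals): the token count is captured with the index.
batch_engine_signals = [(3, 512), (7, 1024)]
cache_prefilled_engine_ids_queue.put(batch_engine_signals)

# Consumer side (the send thread): no re-read of a mutable counter is needed.
for engine_index, current_step_prefilled_token_num in cache_prefilled_engine_ids_queue.get():
    print(engine_index, current_step_prefilled_token_num)
```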
@@ -580,17 +575,20 @@ class CacheMessagerV1:
                    block_start_end_list.append((block_id_start, block_id_end))
                    current_prefilled_token_num_list.append(prefilled_token_num)
                while True:  # from layer0 to last layer
-                    sended_layer_idx = self.idx_cache_task_dict[engine_indexes[0]]["sended_layer_id"]
+                    sended_layer_idx = self.idx_cache_task_dict[batch_engine_signals[0][0]]["sended_layer_id"]
                    start_layer_idx = sended_layer_idx + 1
                    with self.engine_cache_task_thread_lock:  # to check end_layer_idx
-                        prefilled_layer_idx = self.engine_cache_tasks[engine_indexes[0]]["prefilled_layer_idx"]
+                        prefilled_layer_idx = self.engine_cache_tasks[batch_engine_signals[0][0]][
+                            "prefilled_layer_idx"
+                        ]
                        if sended_layer_idx > prefilled_layer_idx:  # computation must in next chunk
                            logger.info(
-                                f"current_prefilled_token_num_list[0] {current_prefilled_token_num_list[0]} prefilled_token_num {self.engine_cache_tasks[engine_indexes[0]]['prefilled_token_num']}"
+                                f"current_prefilled_token_num_list[0] {current_prefilled_token_num_list[0]} prefilled_token_num {self.engine_cache_tasks[batch_engine_signals[0][0]]['prefilled_token_num']}"
                            )
+
                            assert (
                                current_prefilled_token_num_list[0]
-                                < self.engine_cache_tasks[engine_indexes[0]]["prefilled_token_num"]
+                                < self.engine_cache_tasks[batch_engine_signals[0][0]]["prefilled_token_num"]
                            ), "when sended_layer_idx > prefilled_layer_idx, must be in next chunk, but not, sth wrong"
                            end_layer_idx = self.num_layers - 1  # [start_layer_idx, end_layer_idx)
                        else:
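The layer-window logic above, restated as a small function; the else branch is assumed to cap the window at prefilled_layer_idx, which the surrounding code suggests but this hunk does not show:

```python
# Sketch of the chunked-prefill layer window, assuming sended_layer_id is the last
# layer already sent and prefilled_layer_idx the last layer computed so far.
def next_layer_window(sended_layer_idx: int, prefilled_layer_idx: int, num_layers: int):
    start_layer_idx = sended_layer_idx + 1
    if sended_layer_idx > prefilled_layer_idx:
        # The engine has rolled into the next chunk, so every remaining layer of
        # the previous chunk is complete and can be drained to the last layer.
        return start_layer_idx, num_layers - 1
    # Otherwise send only what has been computed (assumed, not shown in the hunk).
    return start_layer_idx, prefilled_layer_idx

print(next_layer_window(sended_layer_idx=5, prefilled_layer_idx=3, num_layers=32))  # (6, 31)
print(next_layer_window(sended_layer_idx=5, prefilled_layer_idx=9, num_layers=32))  # (6, 9)
```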
@@ -599,7 +597,7 @@ class CacheMessagerV1:
                            time.sleep(0.01)
                    for layer_idx in range(start_layer_idx, end_layer_idx + 1):
                        for i, (block_id_start, block_id_end) in enumerate(block_start_end_list):
-                            engine_index = engine_indexes[i]
+                            engine_index = batch_engine_signals[i][0]
                            task = self.idx_cache_task_dict[engine_index]
                            req_id = task["request_id"]
                            if (
@@ -615,7 +613,7 @@ class CacheMessagerV1:
                            if task["transfer_protocol"] == "rdma":
                                target_ip = task["ip"]
                                target_id = int(task["rdma_ports"][self.rank])
-                                if task["status"] == "error":
+                                if "error" in task["status"]:
                                    continue
                                status = self.messager[current_transfer_protocol].connect(target_ip, target_id)
                                if not status:
@@ -665,7 +663,7 @@ class CacheMessagerV1:
                                    block_id_end - block_id_start
                                )
                            if current_prefilled_token_num_list[i] == task["need_prefill_tokens"]:
-                                if task["status"] != "error":
+                                if "error" not in task["status"]:
                                    task["status"] = "finished"
                                    logger.info(
                                        f"finish write cache for all layers, req_id: {req_id}, block_id_end {block_id_end} need_prefill_tokens {task['need_prefill_tokens']}"
@@ -674,18 +672,15 @@ class CacheMessagerV1:
                                    task["sended_layer_id"] = -1
                if end_layer_idx == self.num_layers - 1:
                    with self.engine_cache_task_thread_lock:
-                        for engine_idx in engine_indexes:
+                        for engine_idx, _ in batch_engine_signals:
                            task = self.idx_cache_task_dict[engine_idx]
                            if task["status"] == "finished" or ("error" in task["status"]):
                                target_id = int(task["rdma_ports"][self.rank])
                                if task["transfer_protocol"] == "ipc":
                                    self.messager["ipc"].write_block_by_sync(target_id)
-                                if self.rank == 0:
-                                    # to do: robust in TP, here we assume all status in tp are the same. If wrong, all wrong. If ok, all ok.
-                                    self.engine_worker_queue.put_finished_req(
-                                        [(task["request_id"], task["status"])]
-                                    )
-                                    logger.info(f"put write cache {task['request_id']}, status {task['status']}")
+                                self.engine_worker_queue.finish_send_cache_barrier.wait()
+                                self.engine_worker_queue.put_finished_req([[task["request_id"], task["status"]]])
+                                logger.info(f"put write cache {task['request_id']}, status {task['status']}")
                                self.engine_cache_tasks[task["current_id"]] = dict()
                                del self.cache_info[task["request_id"]]
                                del self.idx_cache_task_dict[task["current_id"]]
@@ -709,8 +704,9 @@ class CacheMessagerV1:
                    continue
                layer_id = kv_signal_data[1].numpy().tolist()
                if layer_id == self.num_layers - 1:
-                    logger.info(f"tasks_count: {tasks_count}, layer_id: {layer_id}")
-                    batch_engine_ids = []
+                    logger.info(f"tasks_count: {tasks_count}, layer_id: {layer_id} self.rank_id {self.rank_id}")
+                    batch_engine_signals = []
+                    # format for signal to put in cache_prefilled_engine_ids_queue: [(engine_idx1, prefilled_token_num1), (engine_idx2, prefilled_token_num2)]
                with self.engine_cache_task_thread_lock:
                    for bi in range(tasks_count):
                        engine_idx = kv_signal_data[3 * bi + 2].numpy().tolist()
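How the layerwise KV signal is unpacked, assuming a flat layout of [tasks_count, layer_id, (engine_idx, chuck_token_offset, current_seq_len) × N], as the 3 * bi + 2 indexing above suggests; a plain list stands in for the tensor:

```python
# Sketch of signal unpacking; the layout and record offsets (+3, +4) are assumed
# from the indexing shown in the diff, and the values below are illustrative.
kv_signal_data = [2, 0, 3, 0, 256, 7, 512, 128]  # tasks_count=2, layer_id=0, two records

tasks_count, layer_id = kv_signal_data[0], kv_signal_data[1]
batch_engine_signals = []
for bi in range(tasks_count):
    engine_idx = kv_signal_data[3 * bi + 2]
    chuck_token_offset = kv_signal_data[3 * bi + 3]
    current_seq_len = kv_signal_data[3 * bi + 4]
    # Snapshot the prefilled token count together with the engine index.
    batch_engine_signals.append((engine_idx, chuck_token_offset + current_seq_len))

print(batch_engine_signals)  # [(3, 256), (7, 640)]
```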
@@ -720,27 +716,33 @@ class CacheMessagerV1:
                        self.engine_cache_tasks[engine_idx]["prefilled_token_num"] = (
                            chuck_token_offset + current_seq_len
                        )
-                        batch_engine_ids.append(engine_idx)
+                        batch_engine_signals.append((engine_idx, chuck_token_offset + current_seq_len))
                if layer_id == 0:
-                    self.cache_prefilled_engine_ids_queue.put(batch_engine_ids)
+                    logger.info(
+                        f"Put batch_engine_signals {batch_engine_signals} into cache_prefilled_engine_ids_queue"
+                    )
+                    self.cache_prefilled_engine_ids_queue.put(batch_engine_signals)
            except Exception as e:
                logger.error(f"Consume signals get exception: {e}")

    def _handle_connect_task(self):
        while True:
            try:
-                task = self.engine_worker_queue.get_connect_rdma_task()
+                task, _ = self.engine_worker_queue.get_connect_rdma_task()
                if task is None:
                    time.sleep(0.001)
                    continue
+                else:
+                    self.engine_worker_queue.connect_task_barrier.wait()
                logger.info(f"_handle_connect_task recv task: {task}")
                task_id = task["task_id"]
-                ip, rdma_port = task["ip"], task["rdma_port"]
+                ip, rdma_port = task["ip"], task["rdma_ports"][self.rank]
                status = self.messager["rdma"].connect(ip, rdma_port)
                if not status:
                    response = {"task_id": task_id, "success": False}
                else:
                    response = {"task_id": task_id, "success": True}
+                self.engine_worker_queue.connect_task_response_barrier.wait()
                self.engine_worker_queue.put_connect_rdma_task_response(response)
            except Exception as e:
                logger.error(f"handle_connect_task has exception: {e}")