[Bug fix] Robust cache messager: send cache correctly when sending cache is slower than prefill (#4659)

@@ -564,14 +564,14 @@ class CacheMessagerV1:
         """
         while True:
             try:
-                engine_indexes = self.cache_prefilled_engine_ids_queue.get()
+                batch_engine_signals = self.cache_prefilled_engine_ids_queue.get()
                 self.engine_worker_queue.finish_request_barrier.wait()
                 block_start_end_list = []
                 current_prefilled_token_num_list = []
-                for engine_index in engine_indexes:
+                for engine_index, current_step_prefilled_token_num in batch_engine_signals:
                     assert engine_index in self.idx_cache_task_dict
                     block_id_start = self.idx_cache_task_dict[engine_index]["sended_block_num"]
-                    prefilled_token_num = self.engine_cache_tasks[engine_index]["prefilled_token_num"]
+                    prefilled_token_num = current_step_prefilled_token_num
                     if (
                         prefilled_token_num == self.idx_cache_task_dict[engine_index]["need_prefill_tokens"]
                     ):  # all chunks have been prefilled
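
The queue payload changes from bare engine ids to `(engine_id, prefilled_token_num)` tuples, so the send thread works against the token count captured when the signal was emitted rather than the live counter, which prefill may have advanced past by the time a slow sender dequeues the batch. A minimal sketch of the idea (illustrative names, not FastDeploy code):

```python
# Minimal sketch: snapshotting the token count into the signal tuple
# decouples a slow sender from the live, still-advancing counter.
import queue
import threading
import time

signals: "queue.Queue[list[tuple[int, int]]]" = queue.Queue()
live_prefilled = {0: 0}  # shared counter that prefill keeps advancing
lock = threading.Lock()

def prefill() -> None:
    for step in (512, 1024, 1536):
        with lock:
            live_prefilled[0] = step
        # snapshot the count *inside* the signal, like batch_engine_signals
        signals.put([(0, step)])
        time.sleep(0.01)

def slow_sender() -> None:
    for _ in range(3):
        for engine_id, snapshot in signals.get():
            time.sleep(0.05)  # sending lags behind prefill
            with lock:
                live = live_prefilled[engine_id]
            # the snapshot is stable; the live counter may already be ahead
            print(f"send up to {snapshot} tokens (live counter is {live})")

threading.Thread(target=prefill).start()
slow_sender()
```
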
@@ -581,17 +581,19 @@ class CacheMessagerV1:
                     block_start_end_list.append((block_id_start, block_id_end))
                     current_prefilled_token_num_list.append(prefilled_token_num)
                 while True:  # from layer0 to last layer
-                    sended_layer_idx = self.idx_cache_task_dict[engine_indexes[0]]["sended_layer_id"]
+                    sended_layer_idx = self.idx_cache_task_dict[batch_engine_signals[0][0]]["sended_layer_id"]
                     start_layer_idx = sended_layer_idx + 1
                     with self.engine_cache_task_thread_lock:  # to check end_layer_idx
-                        prefilled_layer_idx = self.engine_cache_tasks[engine_indexes[0]]["prefilled_layer_idx"]
+                        prefilled_layer_idx = self.engine_cache_tasks[batch_engine_signals[0][0]][
+                            "prefilled_layer_idx"
+                        ]
                         if sended_layer_idx > prefilled_layer_idx:  # computation must in next chunk
                             logger.info(
-                                f"current_prefilled_token_num_list[0] {current_prefilled_token_num_list[0]} prefilled_token_num {self.engine_cache_tasks[engine_indexes[0]]['prefilled_token_num']}"
+                                f"current_prefilled_token_num_list[0] {current_prefilled_token_num_list[0]} prefilled_token_num {self.engine_cache_tasks[batch_engine_signals[0][0]]['prefilled_token_num']}"
                             )
                             assert (
                                 current_prefilled_token_num_list[0]
-                                < self.engine_cache_tasks[engine_indexes[0]]["prefilled_token_num"]
+                                < self.engine_cache_tasks[batch_engine_signals[0][0]]["prefilled_token_num"]
                             ), "when sended_layer_idx > prefilled_layer_idx, must be in next chunk, but not, sth wrong"
                             end_layer_idx = self.num_layers - 1  # [start_layer_idx, end_layer_idx)
                         else:
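
With the snapshot in hand, the send loop can decide how many layers are safe to send on each pass: if the last layer it sent is already beyond the producer's `prefilled_layer_idx`, the producer must have wrapped around to the next chunk, so every remaining layer of the current chunk is complete. A hedged sketch of that window computation, using illustrative names and the same invariant the `assert` encodes:

```python
# Hedged sketch of the layer-window decision; names are illustrative,
# not FastDeploy's.
def sendable_layer_window(
    sended_layer_idx: int,      # last layer already sent for this chunk
    prefilled_layer_idx: int,   # producer's latest finished layer
    snapshot_tokens: int,       # token count captured with the signal
    live_tokens: int,           # producer's live prefilled_token_num
    num_layers: int,
) -> tuple[int, int]:
    """Return the inclusive [start, end] layer range that is safe to send."""
    start = sended_layer_idx + 1
    if sended_layer_idx > prefilled_layer_idx:
        # The producer's layer counter restarted, so it must already be
        # prefilling the *next* chunk; every remaining layer of our chunk
        # is complete and can be sent.
        assert snapshot_tokens < live_tokens, "expected producer in next chunk"
        end = num_layers - 1
    else:
        # Still inside the same chunk: only layers the producer has
        # finished so far are safe.
        end = prefilled_layer_idx
    return start, end

# e.g. sender finished layer 5, producer already restarted at layer 2 of the
# next chunk (live token count advanced): send layers 6..31 of a 32-layer model.
print(sendable_layer_window(5, 2, 1024, 1536, 32))  # -> (6, 31)
```
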
@@ -600,7 +602,7 @@ class CacheMessagerV1:
                             time.sleep(0.01)
                     for layer_idx in range(start_layer_idx, end_layer_idx + 1):
                         for i, (block_id_start, block_id_end) in enumerate(block_start_end_list):
-                            engine_index = engine_indexes[i]
+                            engine_index = batch_engine_signals[i][0]
                             task = self.idx_cache_task_dict[engine_index]
                             req_id = task["request_id"]
                             if (
@@ -675,7 +677,7 @@ class CacheMessagerV1:
                                 task["sended_layer_id"] = -1
                     if end_layer_idx == self.num_layers - 1:
                         with self.engine_cache_task_thread_lock:
-                            for engine_idx in engine_indexes:
+                            for engine_idx, _ in batch_engine_signals:
                                 task = self.idx_cache_task_dict[engine_idx]
                                 if task["status"] == "finished" or ("error" in task["status"]):
                                     target_id = int(task["rdma_ports"][self.rank])
@@ -711,7 +713,8 @@ class CacheMessagerV1:
                 layer_id = kv_signal_data[1].numpy().tolist()
                 if layer_id == self.num_layers - 1:
                     logger.info(f"tasks_count: {tasks_count}, layer_id: {layer_id}")
-                batch_engine_ids = []
+                batch_engine_signals = []
+                # format for signal to put in cache_prefilled_engine_ids_queue: [(engine_idx1, prefilled_token_num1), (engine_idx2, prefilled_token_num2)]
                 with self.engine_cache_task_thread_lock:
                     for bi in range(tasks_count):
                         engine_idx = kv_signal_data[3 * bi + 2].numpy().tolist()
@@ -721,9 +724,9 @@ class CacheMessagerV1:
                         self.engine_cache_tasks[engine_idx]["prefilled_token_num"] = (
                             chuck_token_offset + current_seq_len
                         )
-                        batch_engine_ids.append(engine_idx)
+                        batch_engine_signals.append((engine_idx, chuck_token_offset + current_seq_len))
                 if layer_id == 0:
-                    self.cache_prefilled_engine_ids_queue.put(batch_engine_ids)
+                    self.cache_prefilled_engine_ids_queue.put(batch_engine_signals)
             except Exception as e:
                 logger.error(f"Consume signals get exception: {e}")
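
On the producer side, `consume_signals` flattens per-layer progress into `kv_signal_data` and now enqueues `(engine_idx, prefilled_token_num)` pairs once per chunk, on the layer-0 signal. A sketch of the decoding step; only `layer_id` at index 1 and `engine_idx` at `3 * bi + 2` are visible in the diff, so `tasks_count` at index 0 and the two token fields at `3 * bi + 3` / `3 * bi + 4` are assumptions for illustration:

```python
# Hedged sketch of decoding the flat per-layer signal into queue payloads.
# Index positions marked "assumed" are not shown in the diff.
from typing import List, Tuple

def decode_kv_signal(kv_signal_data: List[int]) -> Tuple[int, List[Tuple[int, int]]]:
    tasks_count = kv_signal_data[0]                    # assumed position
    layer_id = kv_signal_data[1]
    batch_engine_signals = []
    for bi in range(tasks_count):
        engine_idx = kv_signal_data[3 * bi + 2]
        chuck_token_offset = kv_signal_data[3 * bi + 3]  # assumed position
        current_seq_len = kv_signal_data[3 * bi + 4]     # assumed position
        batch_engine_signals.append((engine_idx, chuck_token_offset + current_seq_len))
    return layer_id, batch_engine_signals

# One signal arrives per layer; only the layer-0 signal enqueues the batch,
# so each batch is queued exactly once per chunk.
layer_id, signals = decode_kv_signal([2, 0, 7, 0, 512, 9, 0, 256])
if layer_id == 0:
    print(signals)  # [(7, 512), (9, 256)]
```
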
@@ -596,31 +596,47 @@ class EngineService:
                     batch=num_prefill_batch,
                 )
-                if self.cfg.scheduler_config.splitwise_role != "mixed":
-                    for task in tasks:
-                        # assure can allocate block ids in P
-                        while not self.resource_manager.preallocate_resource_in_p(task):
-                            time.sleep(0.005)
-                        self.llm_logger.info(f"ask D resource for req_id: {task.request_id}")
-                        self.split_connector.send_splitwise_tasks([task], task.idx)
                 need_delete_tasks = []
-                for task in tasks:
-                    if self.cfg.scheduler_config.splitwise_role != "mixed":
-                        # assure fetch block ids from D
-                        status, msg = self.split_connector.check_decode_allocated(task)
-                        if not status:
-                            self.llm_logger.error(f"{task.request_id} prefill failed with msg:{msg}.")
-                            self.scheduler.put_results(
-                                [
-                                    RequestOutput(
-                                        request_id=task.request_id,
-                                        finished=True,
-                                        error_code=500,
-                                        error_msg=msg,
-                                    )
-                                ]
-                            )
-                            need_delete_tasks.append(task)
-                            continue
+                if envs.FD_OFFLINE_PERF_TEST_FOR_PD:
+                    for task in tasks:
+                        # assure can allocate block ids in P
+                        while not self.resource_manager.preallocate_resource_in_p(task):
+                            time.sleep(0.005)
+                        self.llm_logger.info(f"ask D resource for req_id: {task.request_id}")
+                        while True:
+                            self.split_connector.send_splitwise_tasks([task], task.idx)
+                            status, msg = self.split_connector.check_decode_allocated(task)
+                            if not status:
+                                self.llm_logger.error(f"{task.request_id} ask D resource failed, try again.")
+                                time.sleep(0.05)
+                            else:
+                                break
+                else:
+                    for task in tasks:
+                        # assure can allocate block ids in P
+                        while not self.resource_manager.preallocate_resource_in_p(task):
+                            time.sleep(0.005)
+                        self.llm_logger.info(f"ask D resource for req_id: {task.request_id}")
+                        self.split_connector.send_splitwise_tasks([task], task.idx)
+
+                for task in tasks:
+                    if self.cfg.scheduler_config.splitwise_role != "mixed":
+                        # assure fetch block ids from D
+                        status, msg = self.split_connector.check_decode_allocated(task)
+                        if not status:
+                            self.llm_logger.error(f"{task.request_id} prefill failed with msg:{msg}.")
+                            self.scheduler.put_results(
+                                [
+                                    RequestOutput(
+                                        request_id=task.request_id,
+                                        finished=True,
+                                        error_code=500,
+                                        error_msg=msg,
+                                    )
+                                ]
+                            )
+                            need_delete_tasks.append(task)
+                            continue
                 for tmp_task in need_delete_tasks:
                     tasks.remove(tmp_task)
                 # release resource in P
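
Under `FD_OFFLINE_PERF_TEST_FOR_PD`, the prefill side now retries `send_splitwise_tasks` plus `check_decode_allocated` until the decode side grants blocks, instead of failing the request with error 500. A hedged sketch of that policy (`send_and_check` stands in for the two connector calls; the `max_attempts` bound is an addition for illustration, the diff's loop is unbounded):

```python
# Hedged sketch of the perf-test retry policy, not FastDeploy code.
import time
from typing import Callable, Optional, Tuple

def allocate_with_retry(
    send_and_check: Callable[[], Tuple[bool, str]],
    retry_interval_s: float = 0.05,
    max_attempts: Optional[int] = None,  # the diff's loop is unbounded
) -> None:
    attempts = 0
    while True:
        ok, msg = send_and_check()
        if ok:
            return
        attempts += 1
        if max_attempts is not None and attempts >= max_attempts:
            raise TimeoutError(f"decode side never allocated blocks: {msg}")
        time.sleep(retry_interval_s)

# usage sketch: retry until the (stubbed) decode side reports success
results = iter([(False, "no blocks"), (False, "no blocks"), (True, "ok")])
allocate_with_retry(lambda: next(results))
```

An unbounded retry is tolerable in an offline benchmark, where the decode instance is known to free blocks eventually; gating it behind an environment variable keeps the fail-fast path as the production default.
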
@@ -887,6 +903,7 @@ class EngineService:
                         f"{task.request_id} prefill failed with msg:{task.error_msg}, recycle resource."
                     )
+                    continue
                 self.token_processor.tokens_counter[task.request_id] = 1
                 self.resource_manager.insert_task_for_decoding(task)

             else:
@@ -122,6 +122,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "FD_ENABLE_METRIC_LABELS": lambda: bool(int(os.getenv("FD_ENABLE_METRIC_LABELS", "0"))),
     # Default label values in metrics.
     "FD_DEFAULT_METRIC_LABEL_VALUES": lambda: os.getenv("FD_DEFAULT_METRIC_LABEL_VALUES", "{}"),
+    # Enable offline perf test mode for PD disaggregation
+    "FD_OFFLINE_PERF_TEST_FOR_PD": lambda: int(os.getenv("FD_OFFLINE_PERF_TEST_FOR_PD", "0")),
 }
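
The new flag follows the existing `envs` pattern: each entry is a zero-argument lambda, so the variable is read from the environment at access time rather than import time. A minimal sketch of that registry pattern, assuming module-level `__getattr__` dispatch (illustrative, not necessarily FastDeploy's exact mechanism):

```python
# Minimal sketch of a lazy env-var registry (assumed dispatch mechanism).
import os
from typing import Any, Callable

environment_variables: dict[str, Callable[[], Any]] = {
    "FD_OFFLINE_PERF_TEST_FOR_PD": lambda: int(os.getenv("FD_OFFLINE_PERF_TEST_FOR_PD", "0")),
}

def __getattr__(name: str) -> Any:
    # evaluated on each attribute access, so os.environ changes are picked up
    if name in environment_variables:
        return environment_variables[name]()
    raise AttributeError(f"module has no attribute {name!r}")
```

With this pattern, `envs.FD_OFFLINE_PERF_TEST_FOR_PD` evaluates the lambda on every access, so launching with `FD_OFFLINE_PERF_TEST_FOR_PD=1` enables the retry branch without any code change.
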
@@ -275,6 +275,7 @@ class SplitwiseConnector:
             decode_diagg = task.disaggregate_info["cache_info"]
             task.disaggregate_info["cache_info"] = self.cfg.disaggregate_info["cache_info"]
             task.disaggregate_info["cache_info"]["rdma"]["current_id"] = current_id
             task.disaggregate_info["role"] = "decode"
             self._send_message(addr, "prefill", [task])
             task.disaggregate_info["cache_info"] = decode_diagg
+            task.disaggregate_info["role"] = "prefill"
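
The fix restores `task.disaggregate_info["role"]` to `"prefill"` after the send, closing a leak where the in-memory task kept the temporary `"decode"` role. A context-manager sketch of the same save-mutate-restore pattern; the `try`/`finally` is an extra safeguard the diff itself does not add:

```python
# Hedged sketch of the mutate-send-restore pattern in the hunk above;
# try/finally guarantees the restore even if the send raises.
from contextlib import contextmanager
from typing import Any, Dict, Iterator

@contextmanager
def temporary_fields(obj: Dict[str, Any], **overrides: Any) -> Iterator[None]:
    saved = {k: obj[k] for k in overrides}
    obj.update(overrides)
    try:
        yield
    finally:
        obj.update(saved)  # restore originals even on exceptions

# usage sketch:
disaggregate_info = {"role": "prefill", "cache_info": {"rdma": {}}}
with temporary_fields(disaggregate_info, role="decode"):
    pass  # the _send_message(addr, "prefill", [task]) call would go here
assert disaggregate_info["role"] == "prefill"
```
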