[Bug fix] Make the cache messager send cache robustly when sending is slower than prefill (#4659)

Author: chenjian
Date: 2025-11-03 16:37:13 +08:00
Committed by: GitHub
Parent: 561a7ebc0b
Commit: aa7a926931
4 changed files with 59 additions and 36 deletions

View File

@@ -564,14 +564,14 @@ class CacheMessagerV1:
         """
         while True:
             try:
-                engine_indexes = self.cache_prefilled_engine_ids_queue.get()
+                batch_engine_signals = self.cache_prefilled_engine_ids_queue.get()
                 self.engine_worker_queue.finish_request_barrier.wait()
                 block_start_end_list = []
                 current_prefilled_token_num_list = []
-                for engine_index in engine_indexes:
+                for engine_index, current_step_prefilled_token_num in batch_engine_signals:
                     assert engine_index in self.idx_cache_task_dict
                     block_id_start = self.idx_cache_task_dict[engine_index]["sended_block_num"]
-                    prefilled_token_num = self.engine_cache_tasks[engine_index]["prefilled_token_num"]
+                    prefilled_token_num = current_step_prefilled_token_num
                     if (
                         prefilled_token_num == self.idx_cache_task_dict[engine_index]["need_prefill_tokens"]
                     ):  # all chunks have been prefilled
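The point of the payload change above: by the time the send thread dequeues a signal, the live `prefilled_token_num` in `engine_cache_tasks` may already describe a later chunk, so block ranges derived from it would not match the chunk that raised the signal. A minimal standalone sketch of that race, with hypothetical names rather than FastDeploy code:

    import queue
    import threading
    import time

    live = {"prefilled_token_num": 0}   # shared counter the prefill side advances
    signals = queue.Queue()             # carries (engine_idx, snapshot) tuples

    def prefill_worker():
        for chunk_end in (512, 1024, 1536):     # three prefill chunks
            live["prefilled_token_num"] = chunk_end
            signals.put((0, chunk_end))         # snapshot the count at signal time
            time.sleep(0.01)                    # prefill keeps running ahead

    threading.Thread(target=prefill_worker).start()
    time.sleep(0.05)                            # simulate a cache sender that lags behind
    while not signals.empty():
        engine_idx, snapshot = signals.get()
        # Re-reading live["prefilled_token_num"] here would yield 1536 for every
        # signal; the snapshot preserves the chunk boundary each signal refers to.
        print(engine_idx, snapshot, live["prefilled_token_num"])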
@@ -581,17 +581,19 @@ class CacheMessagerV1:
                     block_start_end_list.append((block_id_start, block_id_end))
                     current_prefilled_token_num_list.append(prefilled_token_num)
                 while True:  # from layer0 to last layer
-                    sended_layer_idx = self.idx_cache_task_dict[engine_indexes[0]]["sended_layer_id"]
+                    sended_layer_idx = self.idx_cache_task_dict[batch_engine_signals[0][0]]["sended_layer_id"]
                     start_layer_idx = sended_layer_idx + 1
                     with self.engine_cache_task_thread_lock:  # to check end_layer_idx
-                        prefilled_layer_idx = self.engine_cache_tasks[engine_indexes[0]]["prefilled_layer_idx"]
+                        prefilled_layer_idx = self.engine_cache_tasks[batch_engine_signals[0][0]][
+                            "prefilled_layer_idx"
+                        ]
                         if sended_layer_idx > prefilled_layer_idx:  # computation must in next chunk
                             logger.info(
-                                f"current_prefilled_token_num_list[0] {current_prefilled_token_num_list[0]} prefilled_token_num {self.engine_cache_tasks[engine_indexes[0]]['prefilled_token_num']}"
+                                f"current_prefilled_token_num_list[0] {current_prefilled_token_num_list[0]} prefilled_token_num {self.engine_cache_tasks[batch_engine_signals[0][0]]['prefilled_token_num']}"
                             )
                             assert (
                                 current_prefilled_token_num_list[0]
-                                < self.engine_cache_tasks[engine_indexes[0]]["prefilled_token_num"]
+                                < self.engine_cache_tasks[batch_engine_signals[0][0]]["prefilled_token_num"]
                             ), "when sended_layer_idx > prefilled_layer_idx, must be in next chunk, but not, sth wrong"
                             end_layer_idx = self.num_layers - 1  # [start_layer_idx, end_layer_idx)
                         else:
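A concrete reading of the branch above, assuming two 512-token chunks: if the sender is still working on chunk 1 (snapshot 512) while prefill has moved on to chunk 2, the chunk-1 `sended_layer_id` can exceed the freshly reset `prefilled_layer_idx`; in that case the snapshot must sit strictly below the live counter (512 < 1024), and the sender may stream chunk 1's blocks through the last layer, since prefill being in a later chunk implies every layer of the earlier chunk is complete. Otherwise the assertion flags a real inconsistency.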
@@ -600,7 +602,7 @@ class CacheMessagerV1:
                             time.sleep(0.01)
                     for layer_idx in range(start_layer_idx, end_layer_idx + 1):
                         for i, (block_id_start, block_id_end) in enumerate(block_start_end_list):
-                            engine_index = engine_indexes[i]
+                            engine_index = batch_engine_signals[i][0]
                             task = self.idx_cache_task_dict[engine_index]
                             req_id = task["request_id"]
                             if (
@@ -675,7 +677,7 @@ class CacheMessagerV1:
                                 task["sended_layer_id"] = -1
                     if end_layer_idx == self.num_layers - 1:
                         with self.engine_cache_task_thread_lock:
-                            for engine_idx in engine_indexes:
+                            for engine_idx, _ in batch_engine_signals:
                                 task = self.idx_cache_task_dict[engine_idx]
                                 if task["status"] == "finished" or ("error" in task["status"]):
                                     target_id = int(task["rdma_ports"][self.rank])
@@ -711,7 +713,8 @@ class CacheMessagerV1:
                     layer_id = kv_signal_data[1].numpy().tolist()
                     if layer_id == self.num_layers - 1:
                         logger.info(f"tasks_count: {tasks_count}, layer_id: {layer_id}")
-                    batch_engine_ids = []
+                    batch_engine_signals = []
+                    # format for signal to put in cache_prefilled_engine_ids_queue: [(engine_idx1, prefilled_token_num1), (engine_idx2, prefilled_token_num2)]
                     with self.engine_cache_task_thread_lock:
                         for bi in range(tasks_count):
                             engine_idx = kv_signal_data[3 * bi + 2].numpy().tolist()
@@ -721,9 +724,9 @@ class CacheMessagerV1:
                             self.engine_cache_tasks[engine_idx]["prefilled_token_num"] = (
                                 chuck_token_offset + current_seq_len
                             )
-                            batch_engine_ids.append(engine_idx)
+                            batch_engine_signals.append((engine_idx, chuck_token_offset + current_seq_len))
                     if layer_id == 0:
-                        self.cache_prefilled_engine_ids_queue.put(batch_engine_ids)
+                        self.cache_prefilled_engine_ids_queue.put(batch_engine_signals)
                 except Exception as e:
                     logger.error(f"Consume signals get exception: {e}")
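On the producer side above, each request's count is snapshotted as `chuck_token_offset + current_seq_len` while the task lock is held, and the batch is enqueued only on the `layer_id == 0` signal, so every queue entry appears to pair each engine index with the token count of the chunk being prefilled at that moment, rather than whatever the counter holds when the sender finally catches up.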

View File

@@ -596,31 +596,47 @@ class EngineService:
                     batch=num_prefill_batch,
                 )
             if self.cfg.scheduler_config.splitwise_role != "mixed":
-                for task in tasks:
-                    # assure can allocate block ids in P
-                    while not self.resource_manager.preallocate_resource_in_p(task):
-                        time.sleep(0.005)
-                    self.llm_logger.info(f"ask D resource for req_id: {task.request_id}")
-                    self.split_connector.send_splitwise_tasks([task], task.idx)
+                if envs.FD_OFFLINE_PERF_TEST_FOR_PD:
+                    for task in tasks:
+                        # assure can allocate block ids in P
+                        while not self.resource_manager.preallocate_resource_in_p(task):
+                            time.sleep(0.005)
+                        self.llm_logger.info(f"ask D resource for req_id: {task.request_id}")
+                        while True:
+                            self.split_connector.send_splitwise_tasks([task], task.idx)
+                            status, msg = self.split_connector.check_decode_allocated(task)
+                            if not status:
+                                self.llm_logger.error(f"{task.request_id} ask D resource failed, try again.")
+                                time.sleep(0.05)
+                            else:
+                                break
+                else:
+                    for task in tasks:
+                        # assure can allocate block ids in P
+                        while not self.resource_manager.preallocate_resource_in_p(task):
+                            time.sleep(0.005)
+                        self.llm_logger.info(f"ask D resource for req_id: {task.request_id}")
+                        self.split_connector.send_splitwise_tasks([task], task.idx)
             need_delete_tasks = []
-            for task in tasks:
-                if self.cfg.scheduler_config.splitwise_role != "mixed":
-                    # assure fetch block ids from D
-                    status, msg = self.split_connector.check_decode_allocated(task)
-                    if not status:
-                        self.llm_logger.error(f"{task.request_id} prefill failed with msg:{msg}.")
-                        self.scheduler.put_results(
-                            [
-                                RequestOutput(
-                                    request_id=task.request_id,
-                                    finished=True,
-                                    error_code=500,
-                                    error_msg=msg,
-                                )
-                            ]
-                        )
-                        need_delete_tasks.append(task)
-                        continue
+            for task in tasks:
+                if self.cfg.scheduler_config.splitwise_role != "mixed":
+                    # assure fetch block ids from D
+                    status, msg = self.split_connector.check_decode_allocated(task)
+                    if not status:
+                        self.llm_logger.error(f"{task.request_id} prefill failed with msg:{msg}.")
+                        self.scheduler.put_results(
+                            [
+                                RequestOutput(
+                                    request_id=task.request_id,
+                                    finished=True,
+                                    error_code=500,
+                                    error_msg=msg,
+                                )
+                            ]
+                        )
+                        need_delete_tasks.append(task)
+                        continue
             for tmp_task in need_delete_tasks:
                 tasks.remove(tmp_task)
                 # release resource in P
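The new `envs.FD_OFFLINE_PERF_TEST_FOR_PD` branch targets offline benchmarks, where all requests hit the prefill instance at once and the decode instance may transiently have no free KV blocks; rather than failing such requests, P resends until D allocates. A hedged sketch of the retry pattern, with stand-in stubs for the two connector calls:

    import random
    import time

    def send_splitwise_tasks(tasks):      # stub for SplitwiseConnector.send_splitwise_tasks
        pass

    def check_decode_allocated(task):     # stub: D frees blocks eventually
        return random.random() < 0.3, "no kv blocks free yet"

    def send_until_decode_allocates(task, retry_interval=0.05):
        """Keep resending a prefill task until the decode side allocates blocks."""
        while True:
            send_splitwise_tasks([task])
            ok, msg = check_decode_allocated(task)
            if ok:
                return
            time.sleep(retry_interval)    # D is saturated; back off briefly and retry

    send_until_decode_allocates(object())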
@@ -887,6 +903,7 @@ class EngineService:
                             f"{task.request_id} prefill failed with msg:{task.error_msg}, recycle resource."
                         )
                         continue
+                    self.token_processor.tokens_counter[task.request_id] = 1
                     self.resource_manager.insert_task_for_decoding(task)
                 else:
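The added `tokens_counter` line presumably seeds the token processor with the one token already produced during prefill, so per-request accounting starts from the correct value when the task is handed to `insert_task_for_decoding`.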

View File

@@ -122,6 +122,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "FD_ENABLE_METRIC_LABELS": lambda: bool(int(os.getenv("FD_ENABLE_METRIC_LABELS", "0"))),
     # Default label values in metrics.
     "FD_DEFAULT_METRIC_LABEL_VALUES": lambda: os.getenv("FD_DEFAULT_METRIC_LABEL_VALUES", "{}"),
+    # Enable offline perf test mode for PD disaggregation
+    "FD_OFFLINE_PERF_TEST_FOR_PD": lambda: int(os.getenv("FD_OFFLINE_PERF_TEST_FOR_PD", "0")),
 }
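The new entry follows the table's convention: unset or "0" disables the mode, any other integer enables it. An illustrative, standalone check of how the flag parses:

    import os

    os.environ["FD_OFFLINE_PERF_TEST_FOR_PD"] = "1"   # e.g. exported before launching the P instance
    flag = int(os.getenv("FD_OFFLINE_PERF_TEST_FOR_PD", "0"))
    print(bool(flag))                                 # True -> retry-until-allocated send path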

View File

@@ -275,6 +275,7 @@ class SplitwiseConnector:
decode_diagg = task.disaggregate_info["cache_info"]
task.disaggregate_info["cache_info"] = self.cfg.disaggregate_info["cache_info"]
task.disaggregate_info["cache_info"]["rdma"]["current_id"] = current_id
task.disaggregate_info["role"] = "decode"
self._send_message(addr, "prefill", [task])
task.disaggregate_info["cache_info"] = decode_diagg
task.disaggregate_info["role"] = "prefill"
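The one-line fix makes the serialized task announce itself as "decode" while in flight; the surrounding code already saved, swapped, and restored `cache_info` around `_send_message`, and the role field now follows the same save/mutate/send/restore pattern. A minimal standalone sketch, with a plain dict standing in for the task object:

    # Plain-dict stand-in for the task; not FastDeploy code.
    task = {"disaggregate_info": {"role": "prefill", "cache_info": {"rdma": {"port": 8500}}}}
    decode_view = {"rdma": {"port": 8500, "current_id": 42}}

    saved = task["disaggregate_info"]["cache_info"]   # keep P's own view
    task["disaggregate_info"]["cache_info"] = decode_view
    task["disaggregate_info"]["role"] = "decode"      # the fix: label the outgoing copy
    payload = repr(task)                              # stands in for _send_message(addr, "prefill", [task])
    task["disaggregate_info"]["cache_info"] = saved   # restore local state after sending
    task["disaggregate_info"]["role"] = "prefill"
    print(payload)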