[Optimize][Cherry-pick] Robust stabilty for PD deployment #5338 (#5395)

* [Optimize] Robust stabilty for PD deployment

---------

Co-authored-by: Kaipeng Deng <dengkaipeng@baidu.com>
This commit is contained in:
chenjian
2025-12-15 18:58:09 +08:00
committed by GitHub
parent f133ce501c
commit 4c76171b57
12 changed files with 161 additions and 41 deletions

View File

@@ -125,6 +125,7 @@ class EngineService:
split_connector=self.split_connector,
)
self.token_processor.set_resource_manager(self.resource_manager)
# self.token_processor.enable_monitor_hang()
self.partial_chunked_tokens = [0] * (self.cfg.max_num_partial_prefills + 1)
for idx in range(1, self.cfg.max_num_partial_prefills + 1):
@@ -716,7 +717,6 @@ class EngineService:
is_fetching = False
return
self.llm_logger.debug(f"get tasks from {type(self.scheduler)}: {tasks}")
if self.cfg.scheduler_config.splitwise_role != "mixed":
if self.cfg.scheduler_config.splitwise_role == "prefill":
for task in tasks:

View File

@@ -182,6 +182,7 @@ class Request:
self.async_process_futures = []
self.error_message = None
self.error_code = None
self.last_recv_token_time = None
def __getstate__(self):
"""

View File

@@ -199,6 +199,31 @@ class ResourceManagerV1(ResourceManager):
self.bos_client = None
self.async_preprocess_pool = ThreadPoolExecutor(max_workers=4)
if self.config.scheduler_config.splitwise_role == "decode":
self.preallocated_requests_timestamp = {}
threading.Thread(target=self._monitor_decode_kv_block_recycling, daemon=True).start()
def _monitor_decode_kv_block_recycling(self):
while True:
try:
with self.lock:
need_recycle_request_ids = []
for request_id, timestamp in self.preallocated_requests_timestamp.items():
if time.time() - timestamp >= envs.FD_GET_FIRST_TOKEN_FROM_P_TIMEOUT:
need_recycle_request_ids.append(request_id)
for request_id in need_recycle_request_ids:
del self.preallocated_requests_timestamp[request_id]
for request_id in need_recycle_request_ids:
if request_id in self.requests:
self.pre_recycle_resource(request_id)
llm_logger.error(
f"Recycle block ids for request {request_id} forcefully, due to get first token from P timeout."
f"after recycle: {self.info()}"
)
time.sleep(10)
except Exception as e:
llm_logger.error(f"Monitor recycle block ids in D error: {e}, {str(traceback.format_exc())}")
time.sleep(10)
def allocated_slots(self, request: Request):
return len(request.block_tables) * self.config.cache_config.block_size
@@ -227,8 +252,17 @@ class ResourceManagerV1(ResourceManager):
def reschedule_preempt_task(self, request_id):
with self.lock:
if request_id in self.to_be_rescheduled_request_id_set and request_id in self.requests:
request = self.requests[request_id]
self.waiting.appendleft(request)
if self.config.scheduler_config.splitwise_role == "decode":
request = self.requests[request_id]
self.tasks_list[request.idx] = None
self.stop_flags[request.idx] = True
if request_id in self.requests:
del self.requests[request_id]
if request_id in self.req_dict:
del self.req_dict[request_id]
else:
request = self.requests[request_id]
self.waiting.appendleft(request)
self.to_be_rescheduled_request_id_set.remove(request_id)
def _info_each_block(self):
@@ -262,20 +296,10 @@ class ResourceManagerV1(ResourceManager):
continue
preempted_req.status = RequestStatus.PREEMPTED
preempted_req.num_computed_tokens = 0
if self.config.scheduler_config.splitwise_role == "decode":
self.tasks_list[preempted_req.idx] = None
self.stop_flags[preempted_req.idx] = True
if preempted_req.request_id in self.requests:
del self.requests[preempted_req.request_id]
if preempted_req.request_id in self.req_dict:
del self.req_dict[preempted_req.request_id]
self._free_blocks(preempted_req)
llm_logger.info(f"Preemption is triggered! Preempted request id: {preempted_req.request_id}")
else:
self._free_blocks(preempted_req)
preempted_req.cached_block_num = 0
self.to_be_rescheduled_request_id_set.add(preempted_req.request_id)
llm_logger.info(f"Preemption is triggered! Preempted request id: {preempted_req.request_id}")
self._free_blocks(preempted_req)
preempted_req.cached_block_num = 0
llm_logger.info(f"Preemption is triggered! Preempted request id: {preempted_req.request_id}")
self.to_be_rescheduled_request_id_set.add(preempted_req.request_id)
preempted_reqs.append(preempted_req)
scheduled_reqs.append(self._prepare_preempt_task(preempted_req))
@@ -1014,6 +1038,7 @@ class ResourceManagerV1(ResourceManager):
self.stop_flags[request.idx] = False
self.requests[request.request_id] = request
self.req_dict[request.request_id] = allocated_position
self.preallocated_requests_timestamp[request.request_id] = time.time()
return True
def has_resource_for_prefilled_req(self, request_id: str):
@@ -1032,23 +1057,26 @@ class ResourceManagerV1(ResourceManager):
NOTE: GPU resources should be checked in advance to ensure they are sufficient for the prefilled request.
"""
assert self.config.scheduler_config.splitwise_role == "decode", "Only D instance can call this method"
if request_output.request_id not in self.requests:
self.logger.error(f"Request {request_output.request_id} not found in requests")
return
request = self.requests[request_output.request_id]
with self.lock:
if request_output.request_id not in self.requests:
llm_logger.error(f"Request {request_output.request_id} not found in requests")
return
request = self.requests[request_output.request_id]
# update request and insert to running
request.output_token_ids.append(request_output.outputs.token_ids[0])
request.num_cached_tokens = request_output.num_cached_tokens
if (
self.config.speculative_config.method in ["mtp"]
and self.config.scheduler_config.splitwise_role == "decode"
):
request.draft_token_ids = copy.deepcopy(request_output.outputs.draft_token_ids)
request.need_prefill_tokens = len(request.prompt_token_ids) + 1
request.inference_start_time = time.time()
request.schedule_start_time = time.time()
self.running.append(request)
# update request and insert to running
request.output_token_ids.append(request_output.outputs.token_ids[0])
if request.request_id in self.preallocated_requests_timestamp:
del self.preallocated_requests_timestamp[request.request_id]
request.num_cached_tokens = request_output.num_cached_tokens
if (
self.config.speculative_config.method in ["mtp"]
and self.config.scheduler_config.splitwise_role == "decode"
):
request.draft_token_ids = copy.deepcopy(request_output.outputs.draft_token_ids)
request.need_prefill_tokens = len(request.prompt_token_ids) + 1
request.inference_start_time = time.time()
request.schedule_start_time = time.time()
self.running.append(request)
def _free_blocks(self, request: Request):
if self.config.cache_config.enable_prefix_caching:
@@ -1109,6 +1137,7 @@ class ResourceManagerV1(ResourceManager):
del self.requests[req_id]
if req_id in self.req_dict:
del self.req_dict[req_id]
llm_logger.info(f"after recycle: {self.info()}")
except Exception as e:
llm_logger.error(f"finish_request err: {e}, {str(traceback.format_exc())}")
finally:

View File

@@ -148,6 +148,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"FD_ENGINE_TASK_QUEUE_WITH_SHM": lambda: int(os.getenv("FD_ENGINE_TASK_QUEUE_WITH_SHM", "0")),
"FD_FILL_BITMASK_BATCH": lambda: int(os.getenv("FD_FILL_BITMASK_BATCH", "4")),
"FD_ENABLE_PDL": lambda: int(os.getenv("FD_ENABLE_PDL", "1")),
# Timeout for first token from P in PD disaggregation
"FD_GET_FIRST_TOKEN_FROM_P_TIMEOUT": lambda: int(os.getenv("FD_GET_FIRST_TOKEN_FROM_P_TIMEOUT", "300")),
# Timeout for token processor health check
"FD_TOKEN_PROCESSOR_HEALTH_TIMEOUT": lambda: int(os.getenv("FD_TOKEN_PROCESSOR_HEALTH_TIMEOUT", "120")),
"FD_OUTPUT_TOKEN_HANG_TIMEOUT": lambda: int(os.getenv("FD_OUTPUT_TOKEN_HANG_TIMEOUT", "60")),
}

View File

@@ -127,6 +127,49 @@ class TokenProcessor:
self._finalizer = weakref.finalize(self, self._cleanup_resources)
self._batch_result_buffer = None
# health monitor
self.timestamp_for_alive_before_handle_batch = None
self.timestamp_for_alive_after_handle_batch = None
self.health_lock = threading.Lock()
self.engine_output_token_hang = False
def healthy(self):
"""
whether token processor is healthy
"""
with self.health_lock:
if self.timestamp_for_alive_after_handle_batch is None: # has entered handle batch
if (
self.timestamp_for_alive_before_handle_batch is not None
and time.time() - self.timestamp_for_alive_before_handle_batch
> envs.FD_TOKEN_PROCESSOR_HEALTH_TIMEOUT
):
return False
else:
return True
if self.engine_output_token_hang:
return False
return True
def enable_monitor_hang(self):
self.monitor_thread = threading.Thread(target=self._monitor_output_token_hang)
self.monitor_thread.start()
def _monitor_output_token_hang(self):
while True:
for i in range(self.resource_manager.max_num_seqs):
if self.resource_manager.stop_flags[i]:
continue
task = self.resource_manager.tasks_list[i]
if (
task.last_recv_token_time
and time.time() - task.last_recv_token_time > envs.FD_OUTPUT_TOKEN_HANG_TIMEOUT
):
llm_logger.error(f"Task {task.request_id} hangs")
self.engine_output_token_hang = True
time.sleep(1)
def _cleanup_resources(self):
"""Cleaning up shared memory resources"""
if hasattr(self, "prefill_time_signal"):
@@ -190,6 +233,7 @@ class TokenProcessor:
if self.resource_manager.requests[request_id].idx >= (
batch_size - 1
): # No more token generated for preempted request
self.resource_manager.requests[request_id].last_recv_token_time = None
self.resource_manager.reschedule_preempt_task(request_id)
def _process_per_token(self, task, batch_id: int, token_ids: np.ndarray, result: RequestOutput, is_prefill: bool):
@@ -220,12 +264,12 @@ class TokenProcessor:
llm_logger.info(
f"Request: {task_id} token ratio: {self.tokens_counter[task_id] / (time.time() - task.inference_start_time)}"
)
llm_logger.info(f"{self.resource_manager.info()}")
if self.cfg.speculative_config.method:
self._compute_speculative_status()
if not is_prefill:
self._record_completion_metrics(task, current_time)
self._recycle_resources(task_id, batch_id, task, result, is_prefill)
llm_logger.info(f"{self.resource_manager.info()}")
break
return result
@@ -417,7 +461,14 @@ class TokenProcessor:
continue
llm_logger.debug(f"rank_id {rank_id} self.output_tokens[0, 0] {self.output_tokens[0, 0]}")
self._process_prefill_metrics()
with self.health_lock:
self.timestamp_for_alive_before_handle_batch = time.time()
self.timestamp_for_alive_after_handle_batch = None
self._process_batch_output()
with self.health_lock:
self.timestamp_for_alive_before_handle_batch = None
self.timestamp_for_alive_after_handle_batch = time.time()
except Exception as e:
llm_logger.info(f"while get input_data error: {e} {traceback.format_exc()!s}")
@@ -682,10 +733,12 @@ class TokenProcessor:
+ i * MAX_DRAFT_TOKENS
+ accept_num[i]
].tolist()
if (not recovery_stop) and (len(token_ids) == 0 or token_ids[-1] <= 0):
if (not recovery_stop) and (len(token_ids) == 0 or token_ids[-1] < 0):
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
if task_id in self.resource_manager.to_be_rescheduled_request_id_set:
task.last_recv_token_time = None
self.resource_manager.reschedule_preempt_task(task_id)
continue
else:
token_id = int(tokens[i, 0])
@@ -696,9 +749,16 @@ class TokenProcessor:
if not recovery_stop and token_id < 0:
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
if task_id in self.resource_manager.to_be_rescheduled_request_id_set:
task.last_recv_token_time = None
self.resource_manager.reschedule_preempt_task(task_id)
continue
if self.cfg.scheduler_config.splitwise_role == "decode":
# In D instance, if preempted, error has been reported and resource recycled, tokens generated async not need to be handled
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
if task_id in self.resource_manager.to_be_rescheduled_request_id_set:
continue
if task.get("prefill_chunk_info", None) is not None:
prefill_chunk_num = task.get("prefill_chunk_num", 0)
task.prefill_chunk_num = prefill_chunk_num + 1
@@ -769,6 +829,9 @@ class TokenProcessor:
result.outputs.token_ids.append(token_id)
task.output_token_ids.append(token_id)
task.last_recv_token_time = time.time()
if token_id == 0:
llm_logger.error(f"Request: {task_id} generates token_id 0, maybe wrong inference.")
if self.use_logprobs:
if self.cfg.speculative_config.method:
@@ -804,12 +867,12 @@ class TokenProcessor:
llm_logger.info(
f"Request: {task_id} token ratio: {self.tokens_counter[task_id] / (time.time() - task.inference_start_time)}"
)
llm_logger.info(f"{self.resource_manager.info()}")
if self.cfg.speculative_config.method:
self._compute_speculative_status()
if not is_prefill:
self._record_completion_metrics(task, current_time)
self._recycle_resources(task_id, i, task, result, is_prefill)
llm_logger.info(f"{self.resource_manager.info()}")
break
llm_logger.debug(f"get response from infer: {result}")

View File

@@ -98,6 +98,12 @@ class InternalAdapter:
self.recv_control_cmd_server.response_for_control_cmd(task_id_str, result)
elif task["cmd"] == "connect_rdma":
self.engine.engine_worker_queue.put_connect_rdma_task(task)
elif task["cmd"] == "check_health":
is_health = self.engine.token_processor.healthy()
result = {"task_id": task_id_str, "result": is_health}
logger.debug(f"Response for task: {task_id_str}: is_health {is_health}")
with self.response_lock:
self.recv_control_cmd_server.response_for_control_cmd(task_id_str, result)
except Exception as e:
logger.error(f"handle_control_cmd got error: {e}, {traceback.format_exc()!s}")

View File

@@ -386,6 +386,13 @@ class SplitwiseConnector:
if msg_type == "decode" or msg_type == "prefill":
payload = [output.to_dict() for output in payload]
need_delete_keys = ["video_features", "image_features", "audio_features"]
for tmp_data in payload:
if "multimodal_inputs" not in tmp_data:
continue
for tmp_key in need_delete_keys:
if tmp_key in tmp_data["multimodal_inputs"]:
del tmp_data["multimodal_inputs"][tmp_key]
json_data = json.dumps({"type": msg_type, "payload": payload}).encode("utf-8")
return json_data

View File

@@ -490,6 +490,9 @@ class GPUModelRunner(ModelRunnerBase):
)
rope_3d_position_ids["max_tokens_lst"].append(request.get("max_tokens", 2048))
def get_num_running_request(self):
return self.scheduler_config.max_num_seqs - paddle.sum(self.share_inputs["stop_flags"]).item()
def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = None):
"""
Process scheduler output tasks, used when ENABLE_V1_KVCACHE_SCHEDULER=1

View File

@@ -179,6 +179,9 @@ class GpuWorker(WorkerBase):
"""Get current model"""
return self.model_runner.get_model()
def get_num_running_request(self):
return self.model_runner.get_num_running_request()
def initialize_cache(self, num_gpu_blocks: int) -> None:
"""Initizlize the KV Cache with accurate num_gpu_blocks"""
# accurate cache size

View File

@@ -410,7 +410,7 @@ class PaddleDisWorkerProc:
# Currently, only support single node
self.nnode = int((tp_size + 7) // 8)
req_ids = []
num_running_requests = 0
cur_max_bsz_index = 0
tp_rank = self.local_rank % tp_size
self.model_weights_signal = np.zeros([1], dtype=np.int32)
@@ -485,17 +485,18 @@ class PaddleDisWorkerProc:
req_dicts = []
for req_dict, bsz in tasks:
num_running_requests = int(bsz)
cur_max_bsz_index = int(bsz)
req_dicts.extend(req_dict)
req_ids = [req.request_id for req in req_dicts]
logger.info(
f"Rank: {self.local_rank}, num_running_requests: {num_running_requests}, "
f"Rank: {self.local_rank}, cur_max_bsz_index: {cur_max_bsz_index}, num_running_requests: {self.worker.get_num_running_request()} "
f"num_insert_requests: {len(req_dicts)}, req_ids: {req_ids}"
)
# Process prefill inputs
self.worker.preprocess_new_task(req_dicts, num_running_requests)
self.worker.preprocess_new_task(req_dicts, cur_max_bsz_index)
if (not self.parallel_config.use_ep) and (not self.worker.model_runner.not_need_stop()):
if self.ranks > 1:
@@ -507,7 +508,7 @@ class PaddleDisWorkerProc:
# Execute model to generate token. The generated token will be written to the buffer.
# These generated tokens can be obtained through get_output op.
start_execute_time = time.time()
self.worker.execute_model(req_dicts, num_running_requests)
self.worker.execute_model(req_dicts, cur_max_bsz_index)
self.exist_prefill_task_signal.value[0] = self.worker.exist_prefill()
logger.debug(f"execute model cost: {time.time()-start_execute_time:.5f} s")