[Optimize][Cherry-pick] Robust stabilty for PD deployment #5338 (#5395)

* [Optimize] Robust stabilty for PD deployment

---------

Co-authored-by: Kaipeng Deng <dengkaipeng@baidu.com>
This commit is contained in:
chenjian
2025-12-15 18:58:09 +08:00
committed by GitHub
parent f133ce501c
commit 4c76171b57
12 changed files with 161 additions and 41 deletions

View File

@@ -125,6 +125,7 @@ class EngineService:
split_connector=self.split_connector,
)
self.token_processor.set_resource_manager(self.resource_manager)
# self.token_processor.enable_monitor_hang()
self.partial_chunked_tokens = [0] * (self.cfg.max_num_partial_prefills + 1)
for idx in range(1, self.cfg.max_num_partial_prefills + 1):
@@ -716,7 +717,6 @@ class EngineService:
is_fetching = False
return
self.llm_logger.debug(f"get tasks from {type(self.scheduler)}: {tasks}")
if self.cfg.scheduler_config.splitwise_role != "mixed":
if self.cfg.scheduler_config.splitwise_role == "prefill":
for task in tasks:

View File

@@ -182,6 +182,7 @@ class Request:
self.async_process_futures = []
self.error_message = None
self.error_code = None
self.last_recv_token_time = None
def __getstate__(self):
"""

View File

@@ -199,6 +199,31 @@ class ResourceManagerV1(ResourceManager):
self.bos_client = None
self.async_preprocess_pool = ThreadPoolExecutor(max_workers=4)
if self.config.scheduler_config.splitwise_role == "decode":
self.preallocated_requests_timestamp = {}
threading.Thread(target=self._monitor_decode_kv_block_recycling, daemon=True).start()
def _monitor_decode_kv_block_recycling(self):
while True:
try:
with self.lock:
need_recycle_request_ids = []
for request_id, timestamp in self.preallocated_requests_timestamp.items():
if time.time() - timestamp >= envs.FD_GET_FIRST_TOKEN_FROM_P_TIMEOUT:
need_recycle_request_ids.append(request_id)
for request_id in need_recycle_request_ids:
del self.preallocated_requests_timestamp[request_id]
for request_id in need_recycle_request_ids:
if request_id in self.requests:
self.pre_recycle_resource(request_id)
llm_logger.error(
f"Recycle block ids for request {request_id} forcefully, due to get first token from P timeout."
f"after recycle: {self.info()}"
)
time.sleep(10)
except Exception as e:
llm_logger.error(f"Monitor recycle block ids in D error: {e}, {str(traceback.format_exc())}")
time.sleep(10)
def allocated_slots(self, request: Request):
return len(request.block_tables) * self.config.cache_config.block_size
@@ -227,8 +252,17 @@ class ResourceManagerV1(ResourceManager):
def reschedule_preempt_task(self, request_id):
with self.lock:
if request_id in self.to_be_rescheduled_request_id_set and request_id in self.requests:
request = self.requests[request_id]
self.waiting.appendleft(request)
if self.config.scheduler_config.splitwise_role == "decode":
request = self.requests[request_id]
self.tasks_list[request.idx] = None
self.stop_flags[request.idx] = True
if request_id in self.requests:
del self.requests[request_id]
if request_id in self.req_dict:
del self.req_dict[request_id]
else:
request = self.requests[request_id]
self.waiting.appendleft(request)
self.to_be_rescheduled_request_id_set.remove(request_id)
def _info_each_block(self):
@@ -262,20 +296,10 @@ class ResourceManagerV1(ResourceManager):
continue
preempted_req.status = RequestStatus.PREEMPTED
preempted_req.num_computed_tokens = 0
if self.config.scheduler_config.splitwise_role == "decode":
self.tasks_list[preempted_req.idx] = None
self.stop_flags[preempted_req.idx] = True
if preempted_req.request_id in self.requests:
del self.requests[preempted_req.request_id]
if preempted_req.request_id in self.req_dict:
del self.req_dict[preempted_req.request_id]
self._free_blocks(preempted_req)
llm_logger.info(f"Preemption is triggered! Preempted request id: {preempted_req.request_id}")
else:
self._free_blocks(preempted_req)
preempted_req.cached_block_num = 0
self.to_be_rescheduled_request_id_set.add(preempted_req.request_id)
llm_logger.info(f"Preemption is triggered! Preempted request id: {preempted_req.request_id}")
self._free_blocks(preempted_req)
preempted_req.cached_block_num = 0
llm_logger.info(f"Preemption is triggered! Preempted request id: {preempted_req.request_id}")
self.to_be_rescheduled_request_id_set.add(preempted_req.request_id)
preempted_reqs.append(preempted_req)
scheduled_reqs.append(self._prepare_preempt_task(preempted_req))
@@ -1014,6 +1038,7 @@ class ResourceManagerV1(ResourceManager):
self.stop_flags[request.idx] = False
self.requests[request.request_id] = request
self.req_dict[request.request_id] = allocated_position
self.preallocated_requests_timestamp[request.request_id] = time.time()
return True
def has_resource_for_prefilled_req(self, request_id: str):
@@ -1032,23 +1057,26 @@ class ResourceManagerV1(ResourceManager):
NOTE: GPU resources should be checked in advance to ensure they are sufficient for the prefilled request.
"""
assert self.config.scheduler_config.splitwise_role == "decode", "Only D instance can call this method"
if request_output.request_id not in self.requests:
self.logger.error(f"Request {request_output.request_id} not found in requests")
return
request = self.requests[request_output.request_id]
with self.lock:
if request_output.request_id not in self.requests:
llm_logger.error(f"Request {request_output.request_id} not found in requests")
return
request = self.requests[request_output.request_id]
# update request and insert to running
request.output_token_ids.append(request_output.outputs.token_ids[0])
request.num_cached_tokens = request_output.num_cached_tokens
if (
self.config.speculative_config.method in ["mtp"]
and self.config.scheduler_config.splitwise_role == "decode"
):
request.draft_token_ids = copy.deepcopy(request_output.outputs.draft_token_ids)
request.need_prefill_tokens = len(request.prompt_token_ids) + 1
request.inference_start_time = time.time()
request.schedule_start_time = time.time()
self.running.append(request)
# update request and insert to running
request.output_token_ids.append(request_output.outputs.token_ids[0])
if request.request_id in self.preallocated_requests_timestamp:
del self.preallocated_requests_timestamp[request.request_id]
request.num_cached_tokens = request_output.num_cached_tokens
if (
self.config.speculative_config.method in ["mtp"]
and self.config.scheduler_config.splitwise_role == "decode"
):
request.draft_token_ids = copy.deepcopy(request_output.outputs.draft_token_ids)
request.need_prefill_tokens = len(request.prompt_token_ids) + 1
request.inference_start_time = time.time()
request.schedule_start_time = time.time()
self.running.append(request)
def _free_blocks(self, request: Request):
if self.config.cache_config.enable_prefix_caching:
@@ -1109,6 +1137,7 @@ class ResourceManagerV1(ResourceManager):
del self.requests[req_id]
if req_id in self.req_dict:
del self.req_dict[req_id]
llm_logger.info(f"after recycle: {self.info()}")
except Exception as e:
llm_logger.error(f"finish_request err: {e}, {str(traceback.format_exc())}")
finally: