[PD Disaggregation] [Refine] Refine splitwise deployment (#5151)

* Refine splitwise deployment

* up
This commit is contained in:
Juncai
2025-11-21 15:30:24 +08:00
committed by GitHub
parent 2d1dade5e2
commit f9b0545a7f
15 changed files with 371 additions and 492 deletions

View File

@@ -23,7 +23,7 @@ import time
import traceback
import weakref
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple
import numpy as np
import paddle
@@ -324,18 +324,18 @@ class EngineService:
local_data_parallel_id=self.cfg.parallel_config.local_data_parallel_id,
)
def insert_tasks(self, tasks: Union[List[Request], List[RequestOutput]], current_id=-1):
def insert_tasks(self, tasks: List[Request], current_id=-1):
"""
Insert tasks to engine.
Allocate resource and insert tasks to engine.
Used in v0_kvcache_scheduler.
"""
if not isinstance(tasks, list):
tasks = [tasks]
for task in tasks:
start_span_request("DEQUEUE", task, trace.SpanKind.CONSUMER)
self.resource_manager.check_and_free_block_tables()
if not isinstance(tasks, list):
tasks = [tasks]
need_delete_tasks = []
for task in tasks:
if self.cfg.scheduler_config.splitwise_role != "mixed":
@@ -388,7 +388,11 @@ class EngineService:
is_prefill = True
self.token_processor.number_of_input_tokens += tasks[i].prompt_token_ids_len
self.split_connector.send_cache_infos(tasks, current_id)
if self.cfg.scheduler_config.splitwise_role == "prefill":
self.split_connector.send_cache_info_to_messager(tasks, current_id)
elif self.cfg.scheduler_config.splitwise_role == "decode":
self.split_connector.send_cache_info_to_prefill(tasks)
if not is_decode:
self.llm_logger.info(f"Tasks are sent to engine, req_ids={req_ids}")
for task in tasks:
@@ -406,7 +410,8 @@ class EngineService:
def _insert_prefilled_requests(self, request_outputs: List[RequestOutput]):
"""
insert prefilled requests into engine worker queue.
Decode insert prefilled requests into engine worker queue.
Used in v1_kvcache_scheduler.
Args:
request_outputs: a list of RequestOutput sent by prefill instance
"""
@@ -640,8 +645,9 @@ class EngineService:
time.sleep(0.001)
continue
if self.cfg.scheduler_config.splitwise_role == "decode":
# Decode will insert the request sent by prefill to engine,
# so the task sent by client will be ignored
# TODO: refine scheduler to remove this limitation
# Decode will process and schedule the request sent by prefill to engine,
# so the same request sent by the decode api server will be ignored
continue
llm_logger.debug(f"get tasks from scheduler: {tasks}")
@@ -692,8 +698,9 @@ class EngineService:
trace_print(LoggingEventName.REQUEST_QUEUE_END, task.request_id, getattr(task, "user", ""))
if self.cfg.scheduler_config.splitwise_role == "decode":
# Decode will insert the request sent by prefill to engine,
# so the task sent by client will be ignored
# TODO: refine scheduler to remove this limitation
# Decode will process and schedule the request sent by prefill to engine,
# so the same request sent by the decode api server will be ignored
is_fetching = False
return
@@ -744,11 +751,11 @@ class EngineService:
for tmp_task in need_delete_tasks:
tasks.remove(tmp_task)
# release resource in P
self.resource_manager.prerelease_resource(tmp_task)
self.resource_manager.pre_recycle_resource(tmp_task.request_id)
if self.cfg.scheduler_config.splitwise_role == "prefill":
# to send cache info to cache messager
if tasks:
self.split_connector.send_cache_infos(tasks, 0)
self.split_connector.send_cache_info_to_messager(tasks, 0)
# ensure cache tasks has sent to cache_messager
need_check_req_ids = [task.request_id for task in tasks]
while need_check_req_ids:
@@ -1002,7 +1009,7 @@ class EngineService:
else:
new_contents.append(content)
if len(new_contents):
llm_logger.debug(f"Send response for request id: {request_id}")
llm_logger.debug(f"Send response for request id: {request_id}, {new_contents}")
self.send_response_server.send_response(request_id, new_contents)
except Exception as e:
llm_logger.error(f"Unexcepted error happend: {e}, {traceback.format_exc()!s}")
@@ -1041,7 +1048,7 @@ class EngineService:
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
if self.resource_manager.preallocate_resource_in_d(task):
self.llm_logger.info(f"Resource available, processing task {task.request_id}")
self.split_connector.send_cache_infos([task], -1)
self.split_connector.send_cache_info_to_prefill([task])
processed_indices.append(idx)
is_success = True
else:
@@ -1054,7 +1061,7 @@ class EngineService:
if not is_success:
if not self.enable_decode_cache_task:
task.error_msg = "Not enough resources"
self.split_connector.send_cache_infos([task], -1)
self.split_connector.send_cache_info_to_prefill([task])
processed_indices.append(idx)
else:
self.llm_logger.debug(f"Still waiting for resources {task.request_id}")
@@ -1067,13 +1074,16 @@ class EngineService:
nonlocal prefilled_request_ouputs
ready_request_outputs = []
waiting_request_outputs = []
# Waiting for the api_server and scheduler in decode to
# receive the request sent by the client
for task in prefilled_request_ouputs:
if not hasattr(self.scheduler, "has_request") or self.scheduler.has_request(task.request_id):
ready_request_outputs.append(task)
else:
waiting_request_outputs.append(task)
for req_output in prefilled_request_ouputs:
if hasattr(self.scheduler, "has_request") and not self.scheduler.has_request(req_output.request_id):
# ensure the api_server and scheduler in decode have
# received the request sent by the client
waiting_request_outputs.append(req_output)
continue
ready_request_outputs.append(req_output)
self.llm_logger.debug(f"there are enough resource for prefilled request: {req_output.request_id}")
prefilled_request_ouputs = waiting_request_outputs
if self.cfg.splitwise_version == "v1":
@@ -1083,35 +1093,27 @@ class EngineService:
if not envs.ENABLE_V1_KVCACHE_SCHEDULER:
self._insert_prefilled_requests(ready_request_outputs)
else:
for task in ready_request_outputs:
if envs.FD_ENABLE_INTERNAL_ADAPTER:
if (
not task.outputs.token_ids
): # first token is eos in Prefill, just recycle resource and continue
cur_req = self.resource_manager.requests[task.request_id]
self.resource_manager.stop_flags[cur_req.idx] = True
self.resource_manager.tasks_list[cur_req.idx] = None
self.resource_manager._free_blocks(cur_req)
if cur_req.request_id in self.token_processor.tokens_counter:
del self.token_processor.tokens_counter[task.request_id]
self.llm_logger.warning(f"{task.request_id} need not decode after first token")
del self.resource_manager.requests[task.request_id]
del self.resource_manager.req_dict[task.request_id]
continue
if task.error_code != 200:
cur_req = self.resource_manager.requests[task.request_id]
self.resource_manager.stop_flags[cur_req.idx] = True
self.resource_manager.tasks_list[cur_req.idx] = None
self.resource_manager._free_blocks(cur_req)
if cur_req.request_id in self.token_processor.tokens_counter:
del self.token_processor.tokens_counter[task.request_id]
self.scheduler.put_results([task])
self.llm_logger.warning(
f"{task.request_id} prefill failed with msg:{task.error_msg}, recycle resource."
)
for req_output in ready_request_outputs:
request_id = req_output.request_id
if envs.FD_ENABLE_INTERNAL_ADAPTER and not req_output.outputs.token_ids:
# first token is eos in Prefill, just recycle resource and continue
self.llm_logger.warning(f"{request_id} need not decode after first token")
self.resource_manager.pre_recycle_resource(request_id)
if request_id in self.token_processor.tokens_counter:
del self.token_processor.tokens_counter[request_id]
continue
self.token_processor.tokens_counter[task.request_id] = 1
self.resource_manager.insert_task_for_decoding(task)
if req_output.error_code != 200:
self.llm_logger.warning(
f"{request_id} prefill failed with msg:{req_output.error_msg}, recycle resource."
)
self.resource_manager.pre_recycle_resource(request_id)
if request_id in self.token_processor.tokens_counter:
del self.token_processor.tokens_counter[request_id]
self.scheduler.put_results([req_output])
continue
self.token_processor.tokens_counter[request_id] = 1
self.resource_manager.add_prefilled_request(req_output)
self.llm_logger.debug(f"add prefilled request success, {request_id}")
def decode_loop():
while self.running: