[PD Disaggregation] remove splitwise deployment on single node and refine the code (#4891)

* remove splitwise deployment on single node and refine the code

* up

* up

* up

* add test

* up
Juncai
2025-11-14 09:56:53 +08:00
committed by GitHub
parent 9703108c28
commit 36822fa49c
24 changed files with 626 additions and 963 deletions


@@ -160,8 +160,8 @@ class EngineService:
self.insert_task_to_worker_thread.start()
self.token_processor.tasks_queue = self.engine_worker_queue
self.token_processor.run()
if self.cfg.scheduler_config.splitwise_role != "mixed":
self._process_splitwise_task()
if self.cfg.scheduler_config.splitwise_role == "decode":
self._decode_process_splitwise_requests()
self._register_to_router()
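The hunk above narrows the gate from splitwise_role != "mixed" to == "decode": only decode instances now set up the splitwise consumer. A minimal sketch of that gating, with TinyEngineService as a hypothetical stand-in for EngineService (only the control flow is taken from the diff):

import threading
import time

class TinyEngineService:
    def __init__(self, splitwise_role: str):
        self.splitwise_role = splitwise_role

    def _decode_process_splitwise_requests(self) -> None:
        print("decode: consuming requests handed over by prefill")

    def start(self) -> None:
        # After this change, only the decode role runs the splitwise
        # consumer; "mixed" and "prefill" roles no longer start it.
        if self.splitwise_role == "decode":
            threading.Thread(
                target=self._decode_process_splitwise_requests, daemon=True
            ).start()

TinyEngineService("decode").start()
time.sleep(0.05)  # give the daemon thread a moment to run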
@@ -329,54 +329,13 @@ class EngineService:
local_data_parallel_id=self.cfg.parallel_config.local_data_parallel_id,
)
def insert_tasks(self, tasks: Union[List[Request], List[RequestOutput]], current_id=-1, allocated=False):
def insert_tasks(self, tasks: Union[List[Request], List[RequestOutput]], current_id=-1):
"""
Insert tasks to engine.
"""
for task in tasks:
start_span_request("DEQUEUE", task, trace.SpanKind.CONSUMER)
# TODO: return to scheduler
if allocated:
current_tasks = []
for task in tasks:
cur_task_idx = self.resource_manager.req_dict[task.request_id]
del self.resource_manager.req_dict[task.request_id]
cur_task = self.resource_manager.tasks_list[cur_task_idx]
if envs.FD_ENABLE_INTERNAL_ADAPTER:
if not task.outputs.token_ids: # first token is eos in Prefill, just recycle resource and continue
self.resource_manager.stop_flags[cur_task_idx] = True
self.resource_manager.tasks_list[cur_task_idx] = None
self.resource_manager._recycle_block_tables(cur_task)
if task.request_id in self.token_processor.tokens_counter:
del self.token_processor.tokens_counter[task.request_id]
self.llm_logger.warning(f"{task.request_id} need not decode after first token")
continue
cur_task.prompt_token_ids[0] = task.outputs.token_ids[0]
cur_task.num_cached_tokens = task.num_cached_tokens
if (
self.cfg.speculative_config.method in ["mtp"]
and self.cfg.scheduler_config.splitwise_role == "decode"
):
cur_task.draft_token_ids = copy.deepcopy(task.outputs.draft_token_ids)
if task.error_code != 200:
self.resource_manager.stop_flags[cur_task_idx] = True
self.resource_manager.tasks_list[cur_task_idx] = None
self.resource_manager._recycle_block_tables(cur_task)
if task.request_id in self.token_processor.tokens_counter:
del self.token_processor.tokens_counter[task.request_id]
self.scheduler.put_results([task])
self.llm_logger.warning(
f"{task.request_id} prefill failed with msg:{task.error_msg}, recycle resource."
)
continue
self.token_processor.tokens_counter[task.request_id] = 1
current_tasks.append(cur_task)
if current_tasks:
self.engine_worker_queue.put_tasks((current_tasks, self.resource_manager.real_bsz))
self.llm_logger.debug(f"put task to engine worker queue, task:{current_tasks}")
return True
self.resource_manager.check_and_free_block_tables()
if not isinstance(tasks, list):
@@ -445,8 +404,53 @@ class EngineService:
else:
self.update_mm_requests_chunk_size(tasks)
self.engine_worker_queue.put_tasks((tasks, self.resource_manager.real_bsz))
if is_prefill and self.cfg.scheduler_config.name != "splitwise":
self.engine_worker_queue.available_prefill_instances.put(1)
return True
def _insert_prefilled_requests(self, request_outputs: List[RequestOutput]):
"""
Insert prefilled requests into the engine worker queue.
Args:
request_outputs: a list of RequestOutput sent by prefill instance
"""
to_infer_reqs = []
for req_out in request_outputs:
slot_idx = self.resource_manager.req_dict[req_out.request_id]
del self.resource_manager.req_dict[req_out.request_id]
cur_req = self.resource_manager.tasks_list[slot_idx]
if envs.FD_ENABLE_INTERNAL_ADAPTER:
if not req_out.outputs.token_ids: # first token is eos in Prefill, just recycle resource and continue
self.resource_manager.stop_flags[slot_idx] = True
self.resource_manager.tasks_list[slot_idx] = None
self.resource_manager._recycle_block_tables(cur_req)
if req_out.request_id in self.token_processor.tokens_counter:
del self.token_processor.tokens_counter[req_out.request_id]
self.llm_logger.warning(f"{req_out.request_id} need not decode after first token")
continue
cur_req.prompt_token_ids[0] = req_out.outputs.token_ids[0]
cur_req.num_cached_tokens = req_out.num_cached_tokens
if self.cfg.speculative_config.method in ["mtp"] and self.cfg.scheduler_config.splitwise_role == "decode":
cur_req.draft_token_ids = copy.deepcopy(req_out.outputs.draft_token_ids)
if req_out.error_code != 200:
self.resource_manager.stop_flags[slot_idx] = True
self.resource_manager.tasks_list[slot_idx] = None
self.resource_manager._recycle_block_tables(cur_req)
if req_out.request_id in self.token_processor.tokens_counter:
del self.token_processor.tokens_counter[req_out.request_id]
self.scheduler.put_results([req_out])
self.llm_logger.warning(
f"{req_out.request_id} prefill failed with msg:{req_out.error_msg}, recycle resource."
)
continue
self.token_processor.tokens_counter[req_out.request_id] = 1
to_infer_reqs.append(cur_req)
if to_infer_reqs:
self.engine_worker_queue.put_tasks((to_infer_reqs, self.resource_manager.real_bsz))
self.llm_logger.debug(f"put requests to engine worker queue, task:{to_infer_reqs}")
return True
def task_is_finished(self, index):
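The new _insert_prefilled_requests above keys everything off req_dict, which maps a request id to its slot in tasks_list; on EOS-after-first-token or a prefill error the slot is recycled instead of being dispatched. A hedged sketch of just that bookkeeping, using plain lists and dicts in place of the real ResourceManager:

tasks_list = [object(), object()]      # slot index -> in-flight request
stop_flags = [False, False]            # slot index -> finished marker
req_dict = {"req-0": 0, "req-1": 1}    # request_id -> slot index

def recycle(request_id: str) -> None:
    """Free the slot owned by request_id, e.g. after a failed prefill."""
    slot = req_dict.pop(request_id)    # drop the id -> slot mapping
    stop_flags[slot] = True            # mark the slot finished
    tasks_list[slot] = None            # release the request object

recycle("req-1")
assert tasks_list[1] is None and stop_flags[1]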
@@ -636,8 +640,9 @@ class EngineService:
if len(tasks) == 0:
time.sleep(0.001)
continue
if self.cfg.splitwise_version == "v2" and self.cfg.scheduler_config.splitwise_role == "decode":
# the task in decode instance will be processed in the _process_splitwise_task thread
if self.cfg.scheduler_config.splitwise_role == "decode":
# Decode inserts the requests sent by prefill into the engine,
# so tasks sent by the client are ignored
continue
llm_logger.debug(f"get tasks from scheduler: {tasks}")
@@ -684,7 +689,14 @@ class EngineService:
max_num_batched_tokens=max_num_batched_tokens,
batch=num_prefill_batch,
)
self.llm_logger.debug(f"get tasks from scheduler: {tasks}")
if self.cfg.scheduler_config.splitwise_role == "decode":
# Decode inserts the requests sent by prefill into the engine,
# so tasks sent by the client are ignored
is_fetching = False
return
self.llm_logger.debug(f"get tasks from {type(self.scheduler)}: {tasks}")
if self.cfg.scheduler_config.splitwise_role != "mixed":
need_delete_tasks = []
if envs.FD_OFFLINE_PERF_TEST_FOR_PD:
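Both guards above encode the same invariant: a decode instance never serves requests pulled from its own scheduler, because the authoritative copy arrives from prefill via the engine worker queue. A toy sketch of that guard; role and the task list are illustrative stand-ins:

def fetch_for_inference(role: str, scheduler_tasks: list) -> list:
    if role == "decode":
        # Decode only serves requests handed over by prefill.
        return []
    return scheduler_tasks

assert fetch_for_inference("decode", ["client-req"]) == []
assert fetch_for_inference("prefill", ["client-req"]) == ["client-req"]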
@@ -705,6 +717,7 @@ class EngineService:
for task in tasks:
# assure can allocate block ids in P
while not self.resource_manager.preallocate_resource_in_p(task):
self.llm_logger.info("wait for preallocate_resource_in_p")
time.sleep(0.005)
self.llm_logger.info(f"ask D resource for req_id: {task.request_id}")
self.split_connector.send_splitwise_tasks([task], task.idx)
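The loop above blocks until blocks can be reserved in P, polling every 5 ms, and only then asks D for resources. A minimal sketch of that backoff loop; try_preallocate is a hypothetical stand-in for resource_manager.preallocate_resource_in_p:

import time

def wait_for_preallocation(try_preallocate, poll_interval: float = 0.005) -> None:
    # Retry until the resource manager reports success, sleeping between
    # attempts so the thread does not spin at full speed.
    while not try_preallocate():
        time.sleep(poll_interval)

attempts = iter([False, False, True])   # succeed on the third attempt
wait_for_preallocation(lambda: next(attempts))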
@@ -864,7 +877,7 @@ class EngineService:
request.llm_engine_recv_req_timestamp = time.time()
start_span("ENQUEUE_ZMQ", data, trace.SpanKind.PRODUCER)
main_process_metrics.requests_number.inc()
self.llm_logger.debug(f"Receive request: {request}")
self.llm_logger.debug(f"Receive request from api server: {request}")
except Exception as e:
self.llm_logger.error(f"Receive request error: {e}, {traceback.format_exc()!s}")
err_msg = str(e)
@@ -997,156 +1010,126 @@ class EngineService:
except Exception as e:
llm_logger.error(f"Unexcepted error happend: {e}, {traceback.format_exc()!s}")
def _process_splitwise_task(self):
def _decode_process_splitwise_requests(self):
"""
Processing tasks from engine worker queue in splitwise deployment.
For v0 version, prefill instance gets tasks from engine worker queue.
For v1 and v2 version, decode instance gets raw tasks from engine worker queue to preallocate resources,
and decode instance gets prefilled tasks from engine worker queue to generate tokens.
TODO: unify the communication between decode and prefill instances.
Decode processes requests from the engine worker queue, which are sent by the prefill instance.
TODO: merge this function into the schedule function of the resource manager
"""
allocate_resource_requests: list[Request] = []
prefilled_request_outputs: list[RequestOutput] = []
def receiver_loop():
waiting_resource_requests = []
waiting_ready_tasks = []
def _fetch_requests():
if self.engine_worker_queue.disaggregate_queue_empty():
return
items = self.engine_worker_queue.get_disaggregated_tasks()
for item in items:
tasks = item[1]
if isinstance(tasks[0], Request):
self.llm_logger.debug(f"receive tasks to preallocate resource, {tasks}")
allocate_resource_requests.extend(tasks)
elif isinstance(tasks[0], RequestOutput):
self.llm_logger.debug(f"receive prefilled tasks, {tasks}")
if not isinstance(tasks, list):
tasks = [tasks]
for task in tasks:
task.finished = False
prefilled_request_outputs.extend(tasks)
def _process_allocate_resource_requests():
processed_indices = []
for idx, task in enumerate(allocate_resource_requests):
is_success = False
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
if self.resource_manager.preallocate_resource_in_d(task):
self.llm_logger.info(f"Resource available, processing task {task.request_id}")
self.split_connector.send_cache_infos([task], -1)
processed_indices.append(idx)
is_success = True
else:
if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
self.llm_logger.info(f"Resource available, processing task {task.request_id}")
self.insert_tasks([task])
processed_indices.append(idx)
is_success = True
if not is_success:
if not self.enable_decode_cache_task:
task.error_msg = "Not enough resources"
self.split_connector.send_cache_infos([task], -1)
processed_indices.append(idx)
else:
self.llm_logger.debug(f"Still waiting for resources {task.request_id}")
break
for idx in sorted(processed_indices, reverse=True):
allocate_resource_requests.pop(idx)
def _process_prefilled_requests():
nonlocal prefilled_request_outputs
ready_request_outputs = []
waiting_request_outputs = []
# Waiting for the api_server and scheduler in decode to
# receive the request sent by the client
def _decode_process_prefilled_task_v0_scheduler(input_tasks):
ready_tasks = []
waiting_tasks = []
for task in input_tasks:
if not hasattr(self.scheduler, "has_request") or self.scheduler.has_request(task.request_id):
ready_tasks.append(task)
else:
waiting_tasks.append(task)
self.insert_tasks(ready_tasks, allocated=True)
if self.cfg.splitwise_version in ("v0", "v2"):
self.scheduler.put_results(ready_tasks)
return waiting_tasks
for task in prefilled_request_outputs:
if not hasattr(self.scheduler, "has_request") or self.scheduler.has_request(task.request_id):
ready_request_outputs.append(task)
else:
waiting_request_outputs.append(task)
prefilled_request_outputs = waiting_request_outputs
if self.cfg.splitwise_version == "v1":
# decode return first token to client
self.scheduler.put_results(ready_request_outputs)
if not envs.ENABLE_V1_KVCACHE_SCHEDULER:
self._insert_prefilled_requests(ready_request_outputs)
else:
for task in ready_request_outputs:
if envs.FD_ENABLE_INTERNAL_ADAPTER:
if (
not task.outputs.token_ids
): # first token is eos in Prefill, just recycle resource and continue
cur_req = self.resource_manager.requests[task.request_id]
self.resource_manager.stop_flags[cur_req.idx] = True
self.resource_manager.tasks_list[cur_req.idx] = None
self.resource_manager._free_blocks(cur_req)
if cur_req.request_id in self.token_processor.tokens_counter:
del self.token_processor.tokens_counter[task.request_id]
self.llm_logger.warning(f"{task.request_id} need not decode after first token")
del self.resource_manager.requests[task.request_id]
del self.resource_manager.req_dict[task.request_id]
continue
if task.error_code != 200:
cur_req = self.resource_manager.requests[task.request_id]
self.resource_manager.stop_flags[cur_req.idx] = True
self.resource_manager.tasks_list[cur_req.idx] = None
self.resource_manager._free_blocks(cur_req)
if cur_req.request_id in self.token_processor.tokens_counter:
del self.token_processor.tokens_counter[task.request_id]
self.scheduler.put_results([task])
self.llm_logger.warning(
f"{task.request_id} prefill failed with msg:{task.error_msg}, recycle resource."
)
continue
self.token_processor.tokens_counter[task.request_id] = 1
self.resource_manager.insert_task_for_decoding(task)
def decode_loop():
while self.running:
try:
processed_indices = []
for idx, task in enumerate(waiting_resource_requests):
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
if self.resource_manager.preallocate_resource_in_d(task):
self.llm_logger.info(f"Resource available, processing task {task.request_id}")
self.split_connector.send_cache_infos([task], -1)
processed_indices.append(idx)
else:
self.llm_logger.debug(f"Still waiting for resources {task.request_id}")
break
else:
if self.resource_manager.is_resource_sufficient(task.prompt_token_ids_len):
self.insert_tasks([task])
self.llm_logger.info(f"Resource available, processing task {task.request_id}")
processed_indices.append(idx)
else:
self.llm_logger.debug(f"Still waiting for resources {task.request_id}")
break
for idx in sorted(processed_indices, reverse=True):
waiting_resource_requests.pop(idx)
waiting_ready_tasks = _decode_process_prefilled_task_v0_scheduler(waiting_ready_tasks)
if self.engine_worker_queue.disaggregate_queue_empty():
time.sleep(0.001)
else:
items = self.engine_worker_queue.get_disaggregated_tasks()
for item in items:
role = item[0]
tasks = item[1]
# prefill instance gets tasks from engine worker queue
if role == "prefill":
for task in tasks:
task.max_tokens = task.min_tokens = 2
self.insert_tasks(tasks)
# decode instance gets tasks from engine worker queue
elif role == "decode":
if isinstance(tasks[0], RequestOutput):
self.llm_logger.debug(f"receive prefilled tasks, {tasks}")
if not isinstance(tasks, list):
tasks = [tasks]
for task in tasks:
task.finished = False
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
for task in tasks:
if envs.FD_ENABLE_INTERNAL_ADAPTER:
if (
not task.outputs.token_ids
): # first token is eos in Prefill, just recycle resource and continue
cur_task = self.resource_manager.requests[task.request_id]
self.resource_manager.stop_flags[cur_task.idx] = True
self.resource_manager.tasks_list[cur_task.idx] = None
self.resource_manager._free_blocks(cur_task)
if cur_task.request_id in self.token_processor.tokens_counter:
del self.token_processor.tokens_counter[task.request_id]
self.llm_logger.warning(
f"{task.request_id} need not decode after first token"
)
del self.resource_manager.requests[task.request_id]
del self.resource_manager.req_dict[task.request_id]
continue
if task.error_code != 200:
cur_task = self.resource_manager.requests[task.request_id]
self.resource_manager.stop_flags[cur_task.idx] = True
self.resource_manager.tasks_list[cur_task.idx] = None
self.resource_manager._free_blocks(cur_task)
if cur_task.request_id in self.token_processor.tokens_counter:
del self.token_processor.tokens_counter[task.request_id]
self.scheduler.put_results([task])
self.llm_logger.warning(
f"{task.request_id} prefill failed with msg:{task.error_msg}, recycle resource."
)
continue
self.token_processor.tokens_counter[task.request_id] = 1
self.resource_manager.insert_task_for_decoding(task)
else:
waiting_ready_tasks.extend(_decode_process_prefilled_task_v0_scheduler(tasks))
elif isinstance(tasks[0], Request):
self.llm_logger.debug(f"receive tasks to preallocate resource, {tasks}")
if len(waiting_resource_requests):
self.llm_logger.info(f"Waiting for resource for task {tasks[0].request_id}")
waiting_resource_requests.extend(tasks)
else:
new_waiting = []
for task in tasks:
can_allocate_resource = False
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
if self.resource_manager.preallocate_resource_in_d(task):
self.split_connector.send_cache_infos([task], -1)
can_allocate_resource = True
else:
if self.resource_manager.is_resource_sufficient(
task.prompt_token_ids_len
):
self.insert_tasks([task])
can_allocate_resource = True
if can_allocate_resource is False:
if not self.enable_decode_cache_task:
task.error_msg = "Not enough resources"
new_waiting.append(task)
if new_waiting:
if not self.enable_decode_cache_task:
self.split_connector.send_cache_infos(new_waiting, -1)
else:
waiting_resource_requests.extend(new_waiting)
self.llm_logger.info(
f"Added {len(new_waiting)} tasks to waiting queue"
)
else:
raise ValueError(f"Unsupported task type: {type(tasks[0])}")
_fetch_requests()
_process_allocate_resource_requests()
_process_prefilled_requests()
time.sleep(0.001)
except Exception as e:
self.llm_logger.error(f"Error in main loop: {e}")
time.sleep(0.1)
self.llm_logger.error(
f"Error in main loop of decode_process_splitwise_requests: " f"{e}, {traceback.format_exc()}"
)
time.sleep(0.01)
threading.Thread(target=receiver_loop, daemon=True).start()
threading.Thread(target=decode_loop, daemon=True).start()
def start_cache_service(self, device_ids, ipc_signal_suffix):
return self.resource_manager.cache_manager.launch_cache_manager(
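For reference, the overall shape of the new decode_loop: one daemon thread that, per iteration, drains the disaggregated queue into buffers and then services them, sleeping briefly between iterations and surviving exceptions. A condensed sketch under those assumptions; the three callables stand in for _fetch_requests, _process_allocate_resource_requests, and _process_prefilled_requests:

import threading
import time

def make_decode_loop(fetch, process_allocations, process_prefilled, running=lambda: True):
    def decode_loop() -> None:
        while running():
            try:
                fetch()                  # drain the disaggregated queue into buffers
                process_allocations()    # preallocate resources for raw requests
                process_prefilled()      # hand prefilled requests to the engine
                time.sleep(0.001)
            except Exception as exc:     # keep the loop alive on errors
                print("error in decode loop:", exc)
                time.sleep(0.01)
    return decode_loop

loop = make_decode_loop(lambda: None, lambda: None, lambda: None)
threading.Thread(target=loop, daemon=True).start()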