[Optimize] Support and robustness for tpN for PD (#4595)

* [Optimize] Support and robustness for tpN for PD

* fix

* fix

* support dpM tpN for cache messager

* fix

* fix token counter

* fix bug for merge develop

* fix bug

* make cache messager robust for v0
This commit is contained in:
chenjian
2025-11-03 15:38:31 +08:00
committed by GitHub
parent 7b35488779
commit 25498efcf3
9 changed files with 452 additions and 197 deletions

View File

@@ -310,11 +310,7 @@ class EngineService:
num_client=self.cfg.parallel_config.tensor_parallel_size,
client_id=0,
local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
local_data_parallel_id=min(
self.cfg.worker_num_per_node // self.cfg.parallel_config.tensor_parallel_size * self.cfg.node_rank
+ self.cfg.parallel_config.local_data_parallel_id,
self.cfg.parallel_config.data_parallel_size - 1,
),
local_data_parallel_id=self.cfg.parallel_config.local_data_parallel_id,
)
def insert_tasks(self, tasks, current_id=-1, allocated=False):
@@ -656,39 +652,60 @@ class EngineService:
self.cfg.max_prefill_batch,
)
if self.cfg.scheduler_config.splitwise_role != "mixed":
max_num_batched_tokens = self.cfg.scheduler_config.max_num_batched_tokens
else:
max_num_batched_tokens = self.cfg.model_config.max_model_len
tasks = self.scheduler.get_requests(
available_blocks=self.cfg.cache_config.max_block_num_per_seq,
block_size=self.cfg.cache_config.block_size,
reserved_output_blocks=self.cfg.cache_config.enc_dec_block_num,
max_num_batched_tokens=self.cfg.model_config.max_model_len,
max_num_batched_tokens=max_num_batched_tokens,
batch=num_prefill_batch,
)
if self.cfg.scheduler_config.splitwise_role != "mixed":
for task in tasks:
# assure can allocate block ids in P
while not self.resource_manager.preallocate_resource_in_p(task):
time.sleep(0.005)
self.llm_logger.info(f"ask D resource for req_id: {task.request_id}")
self.split_connector.send_splitwise_tasks([task], task.idx)
need_delete_tasks = []
for task in tasks:
if self.cfg.scheduler_config.splitwise_role != "mixed":
# assure fetch block ids from D
status, msg = self.split_connector.check_decode_allocated(task)
if not status:
self.llm_logger.error(f"{task.request_id} prefill failed with msg:{msg}.")
self.scheduler.put_results(
[
RequestOutput(
request_id=task.request_id,
finished=True,
error_code=500,
error_msg=msg,
)
]
)
need_delete_tasks.append(task)
continue
if envs.FD_OFFLINE_PERF_TEST_FOR_PD:
for task in tasks:
# assure can allocate block ids in P
while not self.resource_manager.preallocate_resource_in_p(task):
time.sleep(0.005)
self.llm_logger.info(f"ask D resource for req_id: {task.request_id}")
while True:
self.split_connector.send_splitwise_tasks([task], task.idx)
status, msg = self.split_connector.check_decode_allocated(task)
if not status:
self.llm_logger.error(f"{task.request_id} ask D resource failed, try again.")
time.sleep(0.05)
else:
break
else:
for task in tasks:
# assure can allocate block ids in P
while not self.resource_manager.preallocate_resource_in_p(task):
time.sleep(0.005)
self.llm_logger.info(f"ask D resource for req_id: {task.request_id}")
self.split_connector.send_splitwise_tasks([task], task.idx)
for task in tasks:
if self.cfg.scheduler_config.splitwise_role != "mixed":
# assure fetch block ids from D
status, msg = self.split_connector.check_decode_allocated(task)
if not status:
self.llm_logger.error(f"{task.request_id} prefill failed with msg:{msg}.")
self.scheduler.put_results(
[
RequestOutput(
request_id=task.request_id,
finished=True,
error_code=500,
error_msg=msg,
)
]
)
need_delete_tasks.append(task)
continue
for tmp_task in need_delete_tasks:
tasks.remove(tmp_task)
# release resource in P
@@ -930,7 +947,7 @@ class EngineService:
for request_id, contents in results.items():
new_contents = []
for content in contents:
if isinstance(content, RequestOutput):
if isinstance(content, RequestOutput) and content.outputs is not None:
decode_type = content.outputs.decode_type
delta_text = ""
if decode_type == 0:
@@ -1035,6 +1052,7 @@ class EngineService:
f"{task.request_id} prefill failed with msg:{task.error_msg}, recycle resource."
)
continue
self.token_processor.tokens_counter[task.request_id] = 1
self.resource_manager.insert_task_for_decoding(task)
else: