[Feature] support ep in mixed mode (#3001)

* [LLM] support ep

* Update worker_process.py

* Update expert_service.py

* Update worker_process.py

* format files
Author: ltd0924
Date: 2025-07-30 20:43:39 +08:00
Committed by: GitHub
Parent: bd29b2aaca
Commit: d17886de19
4 changed files with 58 additions and 52 deletions


@@ -225,6 +225,9 @@ class Config:
         else:
             self.is_master = False
+        if self.tensor_parallel_size <= self.worker_num_per_node:
+            self.is_master = True
         import paddle
         self.paddle_commit_id = paddle.version.commit
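Note (illustration, not part of the patch): the added check forces is_master to True whenever the whole tensor-parallel group fits on a single node, which is the expert-parallel mixed-mode case. A minimal sketch of the resulting decision, using a hypothetical helper and assuming 8 workers per node:

def resolve_is_master(node_is_first: bool, tensor_parallel_size: int, worker_num_per_node: int = 8) -> bool:
    # Pre-patch behaviour (simplified): only the first node is treated as master.
    is_master = node_is_first
    # Patch: if one node already holds the full TP group, every node acts as its own master.
    if tensor_parallel_size <= worker_num_per_node:
        is_master = True
    return is_master

print(resolve_is_master(node_is_first=False, tensor_parallel_size=4))   # True  (TP=4 fits on one node)
print(resolve_is_master(node_is_first=False, tensor_parallel_size=16))  # False (TP=16 spans two nodes)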


@@ -50,8 +50,9 @@ class ExpertService:
             cfg (Config): Config object containing all the configuration parameters.
         """
         self.cfg = cfg
-        start_pos = (local_data_parallel_id * self.cfg.tensor_parallel_size) % self.cfg.worker_num_per_node
-        end_pos = ((local_data_parallel_id + 1) * self.cfg.tensor_parallel_size) % self.cfg.worker_num_per_node
+        start_pos = (local_data_parallel_id * self.cfg.tensor_parallel_size) % cfg.worker_num_per_node
+        end_pos = start_pos + self.cfg.tensor_parallel_size
+        if cfg.splitwise_role != "mixed":
+            self.cfg.cache_config.rdma_comm_ports = self.cfg.cache_config.rdma_comm_ports[start_pos:end_pos]
         self.cfg.local_device_ids = self.cfg.device_ids.split(",")[start_pos:end_pos]
         self.cfg.parallel_config.local_data_parallel_id = local_data_parallel_id
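Note (illustration, not part of the patch): the rewritten end_pos avoids the modulo wrap-around of the old formula, which produced an empty device slice for the data-parallel group whose range ends exactly at the node boundary. A quick sketch with assumed values (8 devices per node, TP=4, local DP rank 1):

tp_size, workers_per_node, dp_rank = 4, 8, 1
device_ids = "0,1,2,3,4,5,6,7".split(",")

old_start = (dp_rank * tp_size) % workers_per_node       # 4
old_end = ((dp_rank + 1) * tp_size) % workers_per_node   # 8 % 8 == 0
print(device_ids[old_start:old_end])                     # [] -> slice [4:0] selects no devices

new_start = (dp_rank * tp_size) % workers_per_node       # 4
new_end = new_start + tp_size                            # 8
print(device_ids[new_start:new_end])                     # ['4', '5', '6', '7']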
@@ -78,9 +79,11 @@ class ExpertService:
             cfg.splitwise_role,
             local_data_parallel_id,
         )
+        if cfg.splitwise_role != "mixed":
             if len(self.cfg.cache_config.pd_comm_port) == 1:
-                self.cfg.cache_config.pd_comm_port[0] = int(self.cfg.cache_config.pd_comm_port[0]) + local_data_parallel_id
+                self.cfg.cache_config.pd_comm_port[0] = (
+                    int(self.cfg.cache_config.pd_comm_port[0]) + local_data_parallel_id
+                )
             else:
                 self.cfg.cache_config.pd_comm_port = [self.cfg.cache_config.pd_comm_port[local_data_parallel_id]]
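Note (illustration, not part of the patch): port selection now runs only for splitwise (PD-disaggregated) roles; in mixed mode pd_comm_port is left untouched. A standalone sketch of the same logic, using a hypothetical helper:

def pick_pd_comm_port(pd_comm_port, local_data_parallel_id, splitwise_role):
    if splitwise_role == "mixed":
        return pd_comm_port                      # mixed mode: no PD cache transfer, keep as-is
    if len(pd_comm_port) == 1:
        # A single base port is offset by the local data-parallel rank.
        return [int(pd_comm_port[0]) + local_data_parallel_id]
    # Otherwise each data-parallel rank gets its own pre-assigned port.
    return [pd_comm_port[local_data_parallel_id]]

print(pick_pd_comm_port(["8500"], 2, "prefill"))           # [8502]
print(pick_pd_comm_port([8500, 8501, 8502], 2, "decode"))  # [8502]
print(pick_pd_comm_port(["8500"], 2, "mixed"))             # ['8500']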
@@ -119,15 +122,16 @@ class ExpertService:
         start_time = time.time()
         llm_logger.info(f"start expert service {local_data_parallel_id}")
+        if self.cfg.splitwise_role != "mixed":
             self.cache_manager_processes = self.resource_manager.cache_manager.launch_cache_manager(
                 cache_config=self.cfg.cache_config,
                 tensor_parallel_size=self.cfg.tensor_parallel_size,
                 device_ids=self.cfg.local_device_ids,
-                pod_ip=self.cfg.master_ip,
+                pod_ip=self.cfg.pod_ips[0],
                 engine_worker_queue_port=self.cfg.engine_worker_queue_port,
                 pid_suffix=f"{local_data_parallel_id}_{ipc_signal_suffix}",
             )
+            self.split_mode_get_tasks()
         self.insert_task_to_worker_thread = threading.Thread(target=self._insert_task_to_worker, args=())
         self.insert_task_to_worker_thread.daemon = True
@@ -138,8 +142,6 @@ class ExpertService:
         self.token_processor.run()
-        self.split_mode_get_tasks()
         self.cfg.init_cache_info()
         role = self.cfg.splitwise_role
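Note (illustration, not part of the patch): taken together, the two hunks above move the splitwise task loop out of the unconditional start-up path and gate it, along with the cache manager, on the deployment role. A compressed sketch of the start-up flow after the patch, with hypothetical callables standing in for the real methods:

def start_expert_service(cfg, launch_cache_manager, split_mode_get_tasks, start_insert_and_token_threads):
    if cfg.splitwise_role != "mixed":
        launch_cache_manager()        # PD-disaggregated roles need the cache transfer manager...
        split_mode_get_tasks()        # ...and the splitwise task-pulling loop
    start_insert_and_token_threads()  # mixed mode goes straight to the normal worker threads

start_expert_service(
    cfg=type("Cfg", (), {"splitwise_role": "mixed"})(),
    launch_cache_manager=lambda: print("launch cache manager"),
    split_mode_get_tasks=lambda: print("start splitwise task loop"),
    start_insert_and_token_threads=lambda: print("start insert/token threads"),
)
# In mixed mode only "start insert/token threads" is printed.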
@@ -321,13 +323,13 @@ class ExpertService:
             else:
                 is_prefill = True
                 self.token_processor.number_of_input_tokens += tasks[i].prompt_token_ids_len
         if is_decode or is_prefill:
             self.split_connector.send_cache_infos(tasks, current_id)
         for task in tasks:
             task.infer_start_time = time.time()
         if not is_decode:
             llm_logger.info(f"Tasks are sent to engine, req_ids={req_ids}")
-        if not is_prefill:
+        if not is_prefill and self.cfg.cache_config.enable_chunked_prefill:
             if not self.cfg.enable_mm:
                 self.update_requests_chunk_size(tasks)
             else:
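Note (illustration, not part of the patch): the extra enable_chunked_prefill check means the chunk-size bookkeeping only runs when chunked prefill is actually turned on, instead of firing for every non-prefill batch. The condition in isolation:

def should_update_chunk_size(is_prefill: bool, enable_chunked_prefill: bool) -> bool:
    # Old condition: not is_prefill
    # New condition: additionally require chunked prefill to be enabled
    return (not is_prefill) and enable_chunked_prefill

print(should_update_chunk_size(is_prefill=False, enable_chunked_prefill=True))   # True
print(should_update_chunk_size(is_prefill=False, enable_chunked_prefill=False))  # False (previously True)
print(should_update_chunk_size(is_prefill=True, enable_chunked_prefill=True))    # False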


@@ -283,14 +283,15 @@ class PaddleDisWorkerProc:
                 paddle.distributed.barrier()
             self.insert_step = False
-            self.worker_healthy_live_signal.value[self.local_rank] = int(time.time())
+            self.worker_healthy_live_signal.value[self.local_rank % self.max_chips_per_node] = int(time.time())
             # The first worker detects whether there are tasks in the task queue
             if self.local_rank % mp_num_per_node == 0:
                 if self.task_queue.num_tasks() > 0:
                     # VL only support 1 batch to prefill
                     if not self.fd_config.model_config.enable_mm or not self.worker.exist_prefill():
-                        if self.nnode > 1:
+                        if self.nnode > 1 and self.parallel_config.tensor_parallel_size > self.max_chips_per_node:
                             self.task_queue.read_finish_flag.set(1)
                         else:
                             self.exist_task_signal.value[self.fd_config.parallel_config.expert_parallel_rank] = 1
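Note (illustration, not part of the patch): the heartbeat index is now taken modulo max_chips_per_node so that ranks beyond one node's chip count still land in a valid slot, and the cross-node read_finish_flag broadcast is only used when a tensor-parallel group actually spans several nodes. A sketch under those assumptions (8 chips per node, hypothetical helpers):

MAX_CHIPS_PER_NODE = 8

def heartbeat_slot(local_rank: int) -> int:
    # Rank 11 (second node) maps back to slot 3 of the node-local heartbeat array.
    return local_rank % MAX_CHIPS_PER_NODE

def needs_cross_node_broadcast(nnode: int, tensor_parallel_size: int) -> bool:
    # Only broadcast the read-finish flag when one TP group spans multiple nodes.
    return nnode > 1 and tensor_parallel_size > MAX_CHIPS_PER_NODE

print(heartbeat_slot(11))                                            # 3
print(needs_cross_node_broadcast(nnode=2, tensor_parallel_size=16))  # True  (TP spans both nodes)
print(needs_cross_node_broadcast(nnode=2, tensor_parallel_size=8))   # False (EP: each node holds a full TP group)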