[EP] fix several bugs in data parallel (#4657)

* Simplify profiling block setup in expert_service.py

Refactor profiling block initialization to avoid duplication.

* Update common_engine.py: account for tensor_parallel_size when deriving the local data parallel id.
Author:    ltd0924
Date:      2025-10-30 09:50:49 +08:00
Committer: GitHub
Parent:    dab04ab413
Commit:    50be19a88a
2 changed files with 17 additions and 17 deletions

File: common_engine.py

@@ -303,7 +303,8 @@ class EngineService:
             client_id=0,
             local_data_parallel_size=self.cfg.parallel_config.data_parallel_size,
             local_data_parallel_id=min(
-                self.cfg.worker_num_per_node * self.cfg.node_rank + self.cfg.parallel_config.local_data_parallel_id,
+                self.cfg.worker_num_per_node // self.cfg.parallel_config.tensor_parallel_size * self.cfg.node_rank
+                + self.cfg.parallel_config.local_data_parallel_id,
                 self.cfg.parallel_config.data_parallel_size - 1,
             ),
         )
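
The fix divides by tensor_parallel_size before multiplying by node_rank: a node hosts worker_num_per_node ranks, but only worker_num_per_node // tensor_parallel_size data-parallel replicas, so the old formula overshot the global replica id whenever TP > 1. A minimal standalone sketch of the corrected mapping (hypothetical helper, not part of the commit):

    def local_dp_id(worker_num_per_node, tensor_parallel_size, node_rank, local_id, data_parallel_size):
        # Each DP replica spans tensor_parallel_size ranks, so a node hosts
        # worker_num_per_node // tensor_parallel_size replicas, not worker_num_per_node.
        dp_per_node = worker_num_per_node // tensor_parallel_size
        # Clamp to the last valid replica id, as the committed min(...) does.
        return min(dp_per_node * node_rank + local_id, data_parallel_size - 1)

    # 8 ranks per node with TP=4 gives 2 replicas per node: on node 1 the first
    # local replica maps to global id 2. The old formula gave 8 * 1 + 0 = 8,
    # which the min(...) clamped to 3 -- still the wrong replica.
    assert local_dp_id(8, 4, node_rank=1, local_id=0, data_parallel_size=4) == 2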

File: expert_service.py

@@ -130,29 +130,28 @@ class ExpertService:
             create=False,
         )
         self.launched_expert_service_signal.value[local_rank] = 1
-        if self.do_profile:
-            get_profile_block_num = np.zeros([1], dtype=np.int32)
-            while True:
-                try:
-                    self.get_profile_block_num_signal = IPCSignal(
-                        name="get_profile_block_num",
-                        array=get_profile_block_num,
-                        dtype=np.int32,
-                        suffix=int(self.cfg.parallel_config.engine_worker_queue_port[0]),
-                        create=False,
-                    )
-                    break
-                except:
-                    time.sleep(1)
-            self.reset_kvcache_blocks()
         if self.cfg.scheduler_config.splitwise_role != "mixed" or self.cfg.cache_config.enable_prefix_caching:
+            if self.do_profile:
+                get_profile_block_num = np.zeros([1], dtype=np.int32)
+                while True:
+                    try:
+                        self.get_profile_block_num_signal = IPCSignal(
+                            name="get_profile_block_num",
+                            array=get_profile_block_num,
+                            dtype=np.int32,
+                            suffix=int(self.cfg.parallel_config.engine_worker_queue_port[0]),
+                            create=False,
+                        )
+                        break
+                    except:
+                        time.sleep(1)
+                self.reset_kvcache_blocks()
             ipc_signal_suffix_cache = self.cfg.parallel_config.engine_worker_queue_port[local_data_parallel_id]
             self.cache_manager_processes = self.engine.start_cache_service(
                 self.cfg.local_device_ids,
                 ipc_signal_suffix_cache,
             )
         console_logger.info(
             f"Worker processes(rank {local_rank}) are launched with {time.time() - start_time} seconds."
         )
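
The attach loop exists because this worker can come up before the engine side has created the get_profile_block_num signal; constructing IPCSignal with create=False fails until the creator has set it up, so the service retries once per second. A stand-in sketch of the same attach-with-retry pattern using the standard library's shared memory (IPCSignal is FastDeploy-internal; the function name and bounded retry count here are illustrative):

    import time
    import numpy as np
    from multiprocessing import shared_memory

    def attach_shared_signal(name, retries=30, delay=1.0):
        # Attach to a shared int32 signal that another process is expected to
        # create, retrying until it exists (mirrors create=False plus retry).
        for _ in range(retries):
            try:
                shm = shared_memory.SharedMemory(name=name, create=False)
                # The caller must keep shm alive: the array is a view into its buffer.
                return shm, np.ndarray((1,), dtype=np.int32, buffer=shm.buf)
            except FileNotFoundError:  # creator side not ready yet
                time.sleep(delay)
        raise TimeoutError(f"shared signal {name!r} never appeared")

Unlike the committed while True loop, the sketch bounds its retries so a creator that never starts surfaces as an error instead of a silent hang, and it catches only the attach failure rather than using a bare except.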