[NewFeture]add ep rollout model init and update/clear ep buffer (#3927)

* add ep rollout model init && add deep update/clear

* fix test
This commit is contained in:
gaoziyuan
2025-09-12 14:15:13 +08:00
committed by GitHub
parent c64ceac34d
commit 10768a4d79
13 changed files with 364 additions and 304 deletions

View File

@@ -259,7 +259,7 @@ class PaddleDisWorkerProc:
local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
self.model_weights_signal = paddle.zeros([1], dtype=paddle.int32)
while True:
if self.local_rank % self.parallel_config.tensor_parallel_size == 0:
if local_rank == 0:
if self.model_weights_status.value[0] != 0:
self.model_weights_signal[0] = int(self.model_weights_status.value[0])
if self.fd_config.load_config.dynamic_load_weight and self.parallel_config.enable_expert_parallel:
@@ -272,7 +272,7 @@ class PaddleDisWorkerProc:
self.worker_healthy_live_signal.value[local_rank % self.max_chips_per_node] = int(time.time())
# The first worker detects whether there are tasks in the task queue
if self.local_rank % self.parallel_config.tensor_parallel_size == 0:
if local_rank == 0:
if self.task_queue.num_tasks() > 0:
# VL only support 1 batch to prefill
if envs.ENABLE_V1_KVCACHE_SCHEDULER or not (
@@ -584,7 +584,7 @@ def parse_args():
parser.add_argument(
"--load_strategy",
type=str,
choices=["ipc", "ipc_snapshot"],
choices=["ipc", "ipc_snapshot", "meta", "normal"],
default="ipc_snapshot",
help="Weight loading method when dynamic loading is enabled: "
"'ipc': real-time IPC streaming with automatic resharding, "
@@ -663,10 +663,11 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
parallel_config.num_experts_per_rank = num_experts_per_rank
parallel_config.num_experts_start_offset = num_experts_start_offset
parallel_config.engine_worker_queue_port = parallel_config.engine_worker_queue_port[
parallel_config.local_data_parallel_id
]
parallel_config.set_tp_group()
if args.load_strategy != "meta":
parallel_config.engine_worker_queue_port = parallel_config.engine_worker_queue_port[
parallel_config.local_data_parallel_id
]
parallel_config.set_communicate_group()
load_config = LoadConfig(vars(args))