[NewFeature] Add EP rollout model init and update/clear EP buffer (#4039)

* fix gid

* merge

* fix test

* fix bug

* fix

* fix ci
Author: gaoziyuan
Date: 2025-09-17 20:24:53 +08:00
Committed by: GitHub
Parent: 0d3a57a2c6
Commit: 896e3bb606
12 changed files with 348 additions and 293 deletions


@@ -338,20 +338,26 @@ class ParallelConfig:
         else:
             self.pd_disaggregation_mode = "None"

-    def set_tp_group(self):
-        # different tp group id
+    def set_communicate_group(self):
+        # prevent different tp_groups using the same group_id
         tp_gid_offset = envs.FD_TP_GROUP_GID_OFFSET
         dist.collective._set_custom_gid(self.data_parallel_rank + tp_gid_offset)
         self.tp_group = dist.new_group(
             range(
                 self.data_parallel_rank * self.tensor_parallel_size,
                 (self.data_parallel_rank + 1) * self.tensor_parallel_size,
             )
         )
         dist.collective._set_custom_gid(None)
         # same ep group id
-        dist.collective._set_custom_gid(self.data_parallel_size + tp_gid_offset)
-        self.ep_group = dist.new_group(range(self.expert_parallel_size))
+        if self.enable_expert_parallel:
+            dist.collective._set_custom_gid(self.data_parallel_size + tp_gid_offset)
+            self.ep_group = dist.new_group(range(self.expert_parallel_size))
         dist.collective._set_custom_gid(None)
         logger.info(
             f"data_parallel_size: {self.data_parallel_size}, tensor_parallel_size: {self.tensor_parallel_size}, expert_parallel_size: {self.expert_parallel_size}, data_parallel_rank: {self.data_parallel_rank}, tensor_parallel_rank: {self.tensor_parallel_rank}, expert_parallel_rank: {self.expert_parallel_rank}, tp_group: {self.tp_group}."
         )
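For readers tracing the group-id bookkeeping in this hunk, the sketch below lays out the scheme as a standalone helper (allocate_group_ids and the example sizes are hypothetical, not FastDeploy code): every data-parallel rank takes custom gid FD_TP_GROUP_GID_OFFSET + data_parallel_rank for its TP group, while the EP group takes FD_TP_GROUP_GID_OFFSET + data_parallel_size, which lies past every TP gid and so cannot collide with one.

# Hypothetical standalone sketch of the gid layout used by set_communicate_group;
# allocate_group_ids() does not exist in FastDeploy and is for illustration only.
def allocate_group_ids(data_parallel_size, tensor_parallel_size, expert_parallel_size, gid_offset):
    # One TP group per DP rank: gid = offset + dp_rank, ranks form a contiguous slice.
    tp_groups = {
        dp_rank: {
            "gid": gid_offset + dp_rank,
            "ranks": list(range(dp_rank * tensor_parallel_size, (dp_rank + 1) * tensor_parallel_size)),
        }
        for dp_rank in range(data_parallel_size)
    }
    # The EP group gid sits just past the last TP gid (offset + data_parallel_size),
    # so it can never clash with any TP group id above.
    ep_group = {"gid": gid_offset + data_parallel_size, "ranks": list(range(expert_parallel_size))}
    return tp_groups, ep_group

tp_groups, ep_group = allocate_group_ids(
    data_parallel_size=2, tensor_parallel_size=4, expert_parallel_size=8, gid_offset=8
)
print(tp_groups)  # {0: {'gid': 8, 'ranks': [0, 1, 2, 3]}, 1: {'gid': 9, 'ranks': [4, 5, 6, 7]}}
print(ep_group)   # {'gid': 10, 'ranks': [0, 1, 2, 3, 4, 5, 6, 7]}

Resetting the custom gid to None after group creation, as the hunk does, presumably returns later dist.new_group calls to the framework's default id assignment.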
@@ -833,6 +839,7 @@ class LoadConfig:
         load_strategy: Specifies the weight loading method when enabled:
             - 'ipc': Real-time IPC streaming with automatic resharding
             - 'ipc_snapshot': Load from disk snapshot of IPC weights
+            - 'meta': Only model meta messages
             - None: No dynamic loading
     """
@@ -843,7 +850,7 @@ class LoadConfig:
         self.load_choices: Union[str, LoadChoices] = LoadChoices.DEFAULT.value
         self.use_fastsafetensor = int(envs.FD_USE_FASTSAFETENSOR) == 1
         self.dynamic_load_weight: bool = False
-        self.load_strategy: Optional[Literal["ipc", "ipc_snapshot"]] = None
+        self.load_strategy: Optional[Literal["ipc", "ipc_snapshot", "meta", "normal"]] = "normal"
         for key, value in args.items():
            if hasattr(self, key):
                setattr(self, key, value)
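To see the effect of the load_strategy changes in the two hunks above, here is a minimal mirror of only the touched lines (LoadConfigSketch is hypothetical; the real class is LoadConfig in fastdeploy/config.py): the default becomes the explicit string "normal" instead of None, and a caller that only needs model meta messages, which is presumably what the EP rollout model init added by this commit relies on, can request "meta".

# Hypothetical mirror of the LoadConfig fields touched in this hunk (illustration only).
from typing import Literal, Optional

class LoadConfigSketch:
    def __init__(self, args: dict):
        self.dynamic_load_weight: bool = False
        # "normal" replaces None as the default; "meta" loads only model meta messages.
        self.load_strategy: Optional[Literal["ipc", "ipc_snapshot", "meta", "normal"]] = "normal"
        for key, value in args.items():
            if hasattr(self, key):
                setattr(self, key, value)

cfg = LoadConfigSketch({"load_strategy": "meta", "dynamic_load_weight": True})
assert cfg.load_strategy == "meta" and cfg.dynamic_load_weight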
@@ -1201,12 +1208,10 @@ class FDConfig:
         num_ranks = self.parallel_config.tensor_parallel_size * self.parallel_config.data_parallel_size
         self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
-        if num_ranks > self.max_chips_per_node:
+        if num_ranks > self.max_chips_per_node and self.load_config.load_strategy != "meta":
             self.worker_num_per_node = self.max_chips_per_node
             nnode = ceil_div(num_ranks, self.worker_num_per_node)
-            assert nnode == self.nnode, f"nnode: {nnode}, but got {self.nnode}"
+            # assert nnode == self.nnode, f"nnode: {nnode}, but got {self.nnode}"
         else:
             self.worker_num_per_node = num_ranks
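The node-sizing change above can be read with the standalone sketch below (plan_workers is a hypothetical helper; the real logic lives in FDConfig, and the reason for the exemption is presumably that a meta-only model init does not occupy physical chips): when more ranks than chips per node are requested, workers are capped per node and the node count comes from ceil_div, but with load_strategy == "meta" the cap is skipped and worker_num_per_node is simply num_ranks.

# Standalone sketch of the node-sizing rule in this hunk; plan_workers() is hypothetical.
def ceil_div(a: int, b: int) -> int:
    # Same ceil_div semantics as used in FDConfig: ceiling of a / b.
    return -(-a // b)

def plan_workers(num_ranks: int, max_chips_per_node: int, load_strategy: str) -> dict:
    if num_ranks > max_chips_per_node and load_strategy != "meta":
        worker_num_per_node = max_chips_per_node
        nnode = ceil_div(num_ranks, worker_num_per_node)
    else:
        # "meta" (or a run that fits on one node) keeps every rank local.
        worker_num_per_node = num_ranks
        nnode = 1  # assumption for this sketch; the hunk does not set nnode in this branch
    return {"worker_num_per_node": worker_num_per_node, "nnode": nnode}

print(plan_workers(16, 8, "normal"))  # {'worker_num_per_node': 8, 'nnode': 2}
print(plan_workers(16, 8, "meta"))    # {'worker_num_per_node': 16, 'nnode': 1}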