[Feature] support model weight update in ep (#3765)

* support model weight update in ep

* support model weight update in ep

* support model weight update in ep

* support model weight update in ep

* Update fused_moe_backend_base.py

* Update worker_process.py

* Update worker_process.py

* Update dynamic_weight_manager.py

This commit is contained in:
ltd0924
2025-09-02 17:16:03 +08:00
committed by GitHub
parent 1908465542
commit 905d89e42f
5 changed files with 47 additions and 22 deletions

View File

@@ -78,6 +78,7 @@ class DeepEPEngine:
splitwise_role: str,
moe_phase: MoEPhase,
async_finish: bool = False,
group=None,
):
"""
Initialize the DeepEP engine.
@@ -90,7 +91,9 @@ class DeepEPEngine:
num_experts: The number of experts.
"""
# TODO(@wufeisheng): Support configurable EP size
self.group = paddle.distributed.new_group(range(ep_size))
if group is None:
group = paddle.distributed.new_group(range(ep_size))
self.group = group
self.ep_size = ep_size
self.rank_id = ep_rank
self.hidden = hidden
@@ -277,6 +280,7 @@ class EPRunner:
ep_size: int = 1,
ep_rank: int = 0,
redundant_experts_num: int = 0,
ep_group=None,
):
self.top_k = top_k
self.num_experts = num_experts
@@ -289,6 +293,7 @@ class EPRunner:
ep_rank=ep_rank,
splitwise_role=splitwise_role,
moe_phase=moe_phase,
group=ep_group,
)
def moe_select(self, layer: nn.Layer, gate_out: paddle.Tensor):
@@ -367,6 +372,7 @@ class EPPrefillRunner(EPRunner):
ep_size: int = 1,
ep_rank: int = 0,
redundant_experts_num: int = 0,
ep_group=None,
moe_phase: MoEPhase = MoEPhase("prefill"),
):
super().__init__(
@@ -379,6 +385,7 @@ class EPPrefillRunner(EPRunner):
ep_size=ep_size,
ep_rank=ep_rank,
redundant_experts_num=redundant_experts_num,
ep_group=ep_group,
)
def dispatch(
@@ -445,6 +452,7 @@ class EPDecoderRunner(EPRunner):
ep_size: int = 1,
ep_rank: int = 0,
redundant_experts_num: int = 0,
ep_group=None,
moe_phase: MoEPhase = MoEPhase("decode"),
):
super().__init__(
@@ -457,6 +465,7 @@ class EPDecoderRunner(EPRunner):
ep_size=ep_size,
ep_rank=ep_rank,
redundant_experts_num=redundant_experts_num,
ep_group=ep_group,
)
def dispatch(