[DeepEP] support P async_finish (#4899)

This commit is contained in:
周周周
2025-11-10 18:24:02 +08:00
committed by GitHub
parent 78895e2c7d
commit 54536267db
2 changed files with 15 additions and 7 deletions

View File

@@ -219,7 +219,7 @@ class DeepEPEngine:
ep_rank: int,
splitwise_role: str,
moe_phase: MoEPhase,
async_finish: bool = False,
async_finish: bool = True,
group=None,
use_internode_ll_two_stage: bool = False,
top_k: int = 8,
@@ -532,8 +532,8 @@ class EPPrefillRunner(EPRunner):
num_tokens_per_rdma_rank,
num_tokens_per_expert,
is_token_in_rank,
_,
) = buffer.get_dispatch_layout(topk_idx, self.num_experts)
event,
) = buffer.get_dispatch_layout(topk_idx, self.num_experts, async_finish=self.ep_engine.async_finish)
x_scale_tensor = kwargs.get("x_scale_tensor", None)
dispatch_args = {
@@ -547,6 +547,7 @@ class EPPrefillRunner(EPRunner):
"topk_idx": topk_idx,
"topk_weights": topk_weights,
"expert_alignment": expert_alignment,
"previous_event": event,
}
return buffer.dispatch(**dispatch_args)
@@ -567,8 +568,8 @@ class EPPrefillRunner(EPRunner):
"async_finish": self.ep_engine.async_finish,
"topk_weights": recv_topk_weights,
}
fused_moe_out, _, _ = buffer.combine(**combine_args)
return fused_moe_out
fused_moe_out, _, event = buffer.combine(**combine_args)
return fused_moe_out, event
class EPDecoderRunner(EPRunner):

View File

@@ -326,10 +326,12 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
recv_topk_weights,
recv_num_tokens_per_expert_list,
handle,
_,
event,
) = self.ep_prefill_runner.dispatch(
x, topk_idx, topk_weights, x_scale_tensor=x_scale_tensor, expert_alignment=128
)
if self.ep_prefill_runner.ep_engine.async_finish:
event.current_stream_wait()
token_all_num = sum(recv_num_tokens_per_expert_list)
@@ -410,7 +412,12 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
tmp_ffn_out = paddle.cast(recv_x[0], paddle.bfloat16)
# 5. EP combine
return self.ep_prefill_runner.combine(tmp_ffn_out, handle, recv_topk_weights)
tmp_ffn_out, event = self.ep_prefill_runner.combine(tmp_ffn_out, handle, recv_topk_weights)
if self.ep_prefill_runner.ep_engine.async_finish:
event.current_stream_wait()
return tmp_ffn_out
def apply_ep_decode(
self,