diff --git a/fastdeploy/model_executor/layers/moe/ep.py b/fastdeploy/model_executor/layers/moe/ep.py index 1733ad5de..745e368ae 100644 --- a/fastdeploy/model_executor/layers/moe/ep.py +++ b/fastdeploy/model_executor/layers/moe/ep.py @@ -219,7 +219,7 @@ class DeepEPEngine: ep_rank: int, splitwise_role: str, moe_phase: MoEPhase, - async_finish: bool = False, + async_finish: bool = True, group=None, use_internode_ll_two_stage: bool = False, top_k: int = 8, @@ -532,8 +532,8 @@ class EPPrefillRunner(EPRunner): num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, - _, - ) = buffer.get_dispatch_layout(topk_idx, self.num_experts) + event, + ) = buffer.get_dispatch_layout(topk_idx, self.num_experts, async_finish=self.ep_engine.async_finish) x_scale_tensor = kwargs.get("x_scale_tensor", None) dispatch_args = { @@ -547,6 +547,7 @@ class EPPrefillRunner(EPRunner): "topk_idx": topk_idx, "topk_weights": topk_weights, "expert_alignment": expert_alignment, + "previous_event": event, } return buffer.dispatch(**dispatch_args) @@ -567,8 +568,8 @@ class EPPrefillRunner(EPRunner): "async_finish": self.ep_engine.async_finish, "topk_weights": recv_topk_weights, } - fused_moe_out, _, _ = buffer.combine(**combine_args) - return fused_moe_out + fused_moe_out, _, event = buffer.combine(**combine_args) + return fused_moe_out, event class EPDecoderRunner(EPRunner): diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py index 4f7a21923..0a1a8a8ee 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py @@ -326,10 +326,12 @@ class DeepGemmFusedMoeMethod(MoEMethodBase): recv_topk_weights, recv_num_tokens_per_expert_list, handle, - _, + event, ) = self.ep_prefill_runner.dispatch( x, topk_idx, topk_weights, x_scale_tensor=x_scale_tensor, expert_alignment=128 ) + if self.ep_prefill_runner.ep_engine.async_finish: + event.current_stream_wait() token_all_num = sum(recv_num_tokens_per_expert_list) @@ -410,7 +412,12 @@ class DeepGemmFusedMoeMethod(MoEMethodBase): tmp_ffn_out = paddle.cast(recv_x[0], paddle.bfloat16) # 5. EP combine - return self.ep_prefill_runner.combine(tmp_ffn_out, handle, recv_topk_weights) + tmp_ffn_out, event = self.ep_prefill_runner.combine(tmp_ffn_out, handle, recv_topk_weights) + + if self.ep_prefill_runner.ep_engine.async_finish: + event.current_stream_wait() + + return tmp_ffn_out def apply_ep_decode( self,