mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[DeepEP] support P async_finish (#4899)
This commit is contained in:
@@ -219,7 +219,7 @@ class DeepEPEngine:
|
||||
ep_rank: int,
|
||||
splitwise_role: str,
|
||||
moe_phase: MoEPhase,
|
||||
async_finish: bool = False,
|
||||
async_finish: bool = True,
|
||||
group=None,
|
||||
use_internode_ll_two_stage: bool = False,
|
||||
top_k: int = 8,
|
||||
@@ -532,8 +532,8 @@ class EPPrefillRunner(EPRunner):
|
||||
num_tokens_per_rdma_rank,
|
||||
num_tokens_per_expert,
|
||||
is_token_in_rank,
|
||||
_,
|
||||
) = buffer.get_dispatch_layout(topk_idx, self.num_experts)
|
||||
event,
|
||||
) = buffer.get_dispatch_layout(topk_idx, self.num_experts, async_finish=self.ep_engine.async_finish)
|
||||
|
||||
x_scale_tensor = kwargs.get("x_scale_tensor", None)
|
||||
dispatch_args = {
|
||||
@@ -547,6 +547,7 @@ class EPPrefillRunner(EPRunner):
|
||||
"topk_idx": topk_idx,
|
||||
"topk_weights": topk_weights,
|
||||
"expert_alignment": expert_alignment,
|
||||
"previous_event": event,
|
||||
}
|
||||
return buffer.dispatch(**dispatch_args)
|
||||
|
||||
@@ -567,8 +568,8 @@ class EPPrefillRunner(EPRunner):
|
||||
"async_finish": self.ep_engine.async_finish,
|
||||
"topk_weights": recv_topk_weights,
|
||||
}
|
||||
fused_moe_out, _, _ = buffer.combine(**combine_args)
|
||||
return fused_moe_out
|
||||
fused_moe_out, _, event = buffer.combine(**combine_args)
|
||||
return fused_moe_out, event
|
||||
|
||||
|
||||
class EPDecoderRunner(EPRunner):
|
||||
|
||||
@@ -326,10 +326,12 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
||||
recv_topk_weights,
|
||||
recv_num_tokens_per_expert_list,
|
||||
handle,
|
||||
_,
|
||||
event,
|
||||
) = self.ep_prefill_runner.dispatch(
|
||||
x, topk_idx, topk_weights, x_scale_tensor=x_scale_tensor, expert_alignment=128
|
||||
)
|
||||
if self.ep_prefill_runner.ep_engine.async_finish:
|
||||
event.current_stream_wait()
|
||||
|
||||
token_all_num = sum(recv_num_tokens_per_expert_list)
|
||||
|
||||
@@ -410,7 +412,12 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
|
||||
tmp_ffn_out = paddle.cast(recv_x[0], paddle.bfloat16)
|
||||
|
||||
# 5. EP combine
|
||||
return self.ep_prefill_runner.combine(tmp_ffn_out, handle, recv_topk_weights)
|
||||
tmp_ffn_out, event = self.ep_prefill_runner.combine(tmp_ffn_out, handle, recv_topk_weights)
|
||||
|
||||
if self.ep_prefill_runner.ep_engine.async_finish:
|
||||
event.current_stream_wait()
|
||||
|
||||
return tmp_ffn_out
|
||||
|
||||
def apply_ep_decode(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user