remove dev sync in prefill (#4598)

This commit is contained in:
周周周
2025-10-27 19:54:43 +08:00
committed by GitHub
parent 64d1aa973b
commit 3729e910a6
2 changed files with 6 additions and 3 deletions

View File

@@ -435,6 +435,7 @@ class EPPrefillRunner(EPRunner):
         x: paddle.Tensor,
         topk_idx: paddle.Tensor,
         topk_weights: paddle.Tensor,
+        expert_alignment: int = 1,
         *args,
         **kwargs,
     ):
@@ -461,6 +462,7 @@ class EPPrefillRunner(EPRunner):
             "async_finish": self.ep_engine.async_finish,
             "topk_idx": topk_idx,
             "topk_weights": topk_weights,
+            "expert_alignment": expert_alignment,
         }
         return buffer.dispatch(**dispatch_args)

View File

@@ -335,7 +335,9 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
             recv_num_tokens_per_expert_list,
             handle,
             _,
-        ) = self.ep_prefill_runner.dispatch(x, topk_idx, topk_weights, x_scale_tensor=x_scale_tensor)
+        ) = self.ep_prefill_runner.dispatch(
+            x, topk_idx, topk_weights, x_scale_tensor=x_scale_tensor, expert_alignment=128
+        )
         token_all_num = sum(recv_num_tokens_per_expert_list)
@@ -345,7 +347,6 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
             (recv_x, recv_x_scale) = recv_x
             token_nums_this_rank = count_tokens_per_expert_func(recv_topk_idx, layer.num_local_experts)
-            token_nums_this_rank_padded = sum(token_nums_this_rank[1].numpy().tolist())
             (
                 permute_input,
@@ -365,7 +366,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
                 token_nums_this_rank[0],
                 token_nums_this_rank[1],
                 True,  # use_in_ep
-                token_nums_this_rank_padded,
+                token_all_num,
             )
             permute_scale = permute_scale.transpose([1, 0]).contiguous()