fix ep prefill (#2762)

RichardWooSJTU
2025-07-09 14:03:05 +08:00
committed by GitHub
parent c4718fd693
commit fee544e808
7 changed files with 66 additions and 32 deletions


@@ -144,7 +144,10 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
         if token_all_num > 0:
             logger.info(f"token_all_num {token_all_num}")
             (recv_x, recv_x_scale) = recv_x
-            tmp = count_tokens_per_expert_func(recv_topk_idx, layer.num_local_experts)
+            token_nums_this_rank = count_tokens_per_expert_func(recv_topk_idx, layer.num_local_experts)
+            token_nums_this_rank_padded = sum(token_nums_this_rank[1].numpy().tolist())
             (
                 permute_input,
                 permute_scale,
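
Note on the hunk above: the renamed token_nums_this_rank keeps the tuple returned by count_tokens_per_expert_func, and the new token_nums_this_rank_padded collapses its second element into a plain Python int before the permute call. A minimal NumPy sketch of the contract this code appears to assume, where element [0] is the per-expert token count and element [1] is that count padded to a block size (the sketch's helper name, the padding rule, and pad_multiple are illustrative assumptions, not the project's real kernel):

import numpy as np

def count_tokens_per_expert_sketch(topk_idx, num_local_experts, pad_multiple=128):
    # Count how many routed tokens land on each local expert.
    counts = np.bincount(topk_idx.reshape(-1), minlength=num_local_experts)
    # Pad each expert's count up to a block multiple (assumed padding rule).
    padded = (counts + pad_multiple - 1) // pad_multiple * pad_multiple
    return counts, padded

topk_idx = np.array([[0, 2], [1, 2], [2, 3]])
counts, padded = count_tokens_per_expert_sketch(topk_idx, num_local_experts=4)
# Mirrors token_nums_this_rank_padded above: a plain Python int, not a tensor.
token_nums_this_rank_padded = int(padded.sum())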
@@ -160,8 +163,10 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
                 recv_x_scale,
                 recv_topk_idx,
                 recv_topk_weights,
-                tmp[0],
-                tmp[1]
+                token_nums_this_rank[0],
+                token_nums_this_rank[1],
+                True, # use_in_ep
+                token_nums_this_rank_padded,
             )
             permute_scale = permute_scale.transpose([1, 0]).contiguous()
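
The two trailing arguments added here let the permute step size its output buffer from the caller's precomputed padded total rather than re-reducing the per-expert counts on device. A rough, self-contained sketch of that behavior (the function name, argument order, and gather loop are hypothetical, not the real fused op):

import numpy as np

def permute_sketch(x, topk_idx, counts_padded, use_in_ep, token_all_num_padded):
    # EP prefill supplies the precomputed padded total (use_in_ep=True);
    # the dense path passes -1, and the total is derived from the padded
    # per-expert counts instead.
    total = token_all_num_padded if use_in_ep else int(counts_padded.sum())
    permute_input = np.zeros((total, x.shape[1]), dtype=x.dtype)
    # Expert e's rows start at the prefix sum of the padded counts.
    cursor = np.concatenate(([0], np.cumsum(counts_padded)[:-1]))
    for tok in range(topk_idx.shape[0]):
        for e in topk_idx[tok]:
            permute_input[cursor[e]] = x[tok]
            cursor[e] += 1
    return permute_input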
@@ -328,6 +333,8 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
             topk_weights,
             tmp[0],
             tmp[1],
+            False, # use_in_ep
+            -1,
         )
         permute_scale = permute_scale.transpose([1, 0]).contiguous()
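
For contrast, this second call site is the dense (non-EP) path and keeps its old behavior: the flag is False, -1 acts as a sentinel for the padded total, and the pre-rename tmp name is still used here. Reusing the sketches from the earlier hunks, the two call shapes look like:

x = np.random.rand(3, 8).astype(np.float32)
# EP prefill: flag on, padded total supplied by the caller.
out_ep = permute_sketch(x, topk_idx, padded, True, token_nums_this_rank_padded)
# Dense path: flag off, -1 sentinel ignored, total re-derived from the counts.
out_dense = permute_sketch(x, topk_idx, padded, False, -1)
assert out_ep.shape == out_dense.shape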