Mirror of https://github.com/PaddlePaddle/FastDeploy.git
fix ep prefill (#2762)
@@ -144,7 +144,10 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
         if token_all_num > 0:
             logger.info(f"token_all_num {token_all_num}")
             (recv_x, recv_x_scale) = recv_x
-            tmp = count_tokens_per_expert_func(recv_topk_idx, layer.num_local_experts)
+
+            token_nums_this_rank = count_tokens_per_expert_func(recv_topk_idx, layer.num_local_experts)
+            token_nums_this_rank_padded = sum(token_nums_this_rank[1].numpy().tolist())
+
             (
                 permute_input,
                 permute_scale,
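Note on the hunk above: it replaces the throwaway name tmp with token_nums_this_rank and additionally sums the second returned tensor on the host. Below is a minimal sketch of that bookkeeping, assuming (not confirmed by this hunk) that count_tokens_per_expert_func returns a pair whose second element holds per-expert token counts padded up to a fixed block size; NumPy stands in for Paddle tensors and the helper is hypothetical, not the FastDeploy implementation.

import numpy as np

# Hypothetical stand-in for count_tokens_per_expert_func: given top-k expert ids
# routed to this rank, return (raw per-expert counts, counts rounded up to a block).
# The real FastDeploy helper may differ; this only shows where
# token_nums_this_rank_padded comes from.
def count_tokens_per_expert(topk_idx, num_local_experts, block=128):
    counts = np.bincount(topk_idx.ravel(), minlength=num_local_experts)
    padded = (counts + block - 1) // block * block
    return counts, padded

topk_idx = np.random.randint(0, 8, size=(1024, 2))  # 1024 tokens, top-2 routing
token_nums_this_rank = count_tokens_per_expert(topk_idx, num_local_experts=8)

# Mirrors `sum(token_nums_this_rank[1].numpy().tolist())` in the diff
# (a Paddle tensor is moved to the host before summing).
token_nums_this_rank_padded = int(sum(token_nums_this_rank[1].tolist()))
print(token_nums_this_rank_padded)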
@@ -160,8 +163,10 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
                 recv_x_scale,
                 recv_topk_idx,
                 recv_topk_weights,
-                tmp[0],
-                tmp[1]
+                token_nums_this_rank[0],
+                token_nums_this_rank[1],
+                True, # use_in_ep
+                token_nums_this_rank_padded,
             )

             permute_scale = permute_scale.transpose([1, 0]).contiguous()
@@ -328,6 +333,8 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
                 topk_weights,
                 tmp[0],
                 tmp[1],
+                False, # use_in_ep
+                -1,
             )

             permute_scale = permute_scale.transpose([1, 0]).contiguous()
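Across the two call sites, the permute op now receives two extra trailing arguments: the expert-parallel prefill path passes True plus the padded per-rank token total, while the regular path passes False and -1. The sketch below is a hedged illustration of how such a flag/sentinel pair is typically consumed; the function is hypothetical and not the actual FastDeploy kernel, and the meaning of -1 as "fall back to the unpadded count" is an assumption.

# Hypothetical consumer of the (use_in_ep, token_num_padded) pair added by this
# commit; the real permute op is a compiled kernel with a larger signature.
def resolve_token_num(use_in_ep: bool, token_num_padded: int, token_num: int) -> int:
    if use_in_ep:
        # EP prefill: buffers are sized by the precomputed padded total.
        assert token_num_padded >= 0
        return token_num_padded
    # Regular path: -1 acts as a sentinel, so the unpadded token count is used.
    return token_num

print(resolve_token_num(True, 2048, 1500))   # EP prefill -> 2048
print(resolve_token_num(False, -1, 1500))    # regular path -> 1500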