[Optimization] EP empty_input_forward Remove Communication (#5254)

This commit is contained in:
chen
2025-12-01 21:10:40 +08:00
committed by GitHub
parent b0113cb0fc
commit aa35ce449d
4 changed files with 11 additions and 4 deletions

View File

@@ -58,6 +58,11 @@
__VA_ARGS__ \
break; \
} \
case 20: { \
constexpr size_t NUM_EXPERTS_PER_RANK = 20; \
__VA_ARGS__ \
break; \
} \
case 32: { \
constexpr size_t NUM_EXPERTS_PER_RANK = 32; \
__VA_ARGS__ \

View File

@@ -146,8 +146,10 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
recv_topk_weights,
recv_num_tokens_per_expert_list,
handle,
_,
event,
) = self.ep_prefill_runner.dispatch(x, topk_idx, topk_weights)
if self.ep_prefill_runner.ep_engine.async_finish:
event.current_stream_wait()
token_all_num = sum(recv_num_tokens_per_expert_list)
# 3. Compute ffn

View File

@@ -498,8 +498,8 @@ class Glm4MoeForCausalLM(ModelForCasualLM):
"""
empty_input_forward
"""
fake_hidden_states = paddle.ones(
shape=[1, self.fd_config.model_config.hidden_size],
fake_hidden_states = paddle.empty(
shape=[0, self.fd_config.model_config.hidden_size],
dtype=paddle.get_default_dtype(),
)
for i in range(

View File

@@ -421,7 +421,7 @@ class Qwen3MoeForCausalLM(ModelForCasualLM):
empty_input_forward
"""
fake_hidden_states = paddle.empty(
shape=[1, self.fd_config.model_config.hidden_size],
shape=[0, self.fd_config.model_config.hidden_size],
dtype=paddle.get_default_dtype(),
)
for i in range(