[Fix] fix expert_parallel bug in decoder stage (#2848)

Author: freeliuzc
Date: 2025-07-16 11:08:18 +08:00
Committed by: GitHub
Parent: 17314ee126
Commit: 2d1184aefe
2 changed files with 3 additions and 3 deletions


@@ -49,7 +49,7 @@ class MoEMethodBase(QuantMethodBase):
                 from .ep import EPDecoderRunner
                 self.ep_decoder_runner = EPDecoderRunner(
                     layer.top_k, layer.hidden_size, layer.num_experts,
-                    layer.model_config.num_max_dispatch_tokens_per_rank,
+                    layer.fd_config.model_config.num_max_dispatch_tokens_per_rank,
                     layer.ep_size, layer.ep_rank)
             else:
                 from .ep import EPPrefillRunner

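Note: both files share the same root cause. The EP decoder runner read num_max_dispatch_tokens_per_rank from layer.model_config, but the MoE layer carries its model config nested inside an fd_config object, so the old lookup breaks once expert parallelism reaches the decode stage. Below is a minimal sketch of the assumed config nesting; the FDConfig/ModelConfig stand-ins and the default value are illustrative, only the attribute path comes from the diff.

from dataclasses import dataclass

@dataclass
class ModelConfig:
    # Illustrative default; the real value comes from the serving config.
    num_max_dispatch_tokens_per_rank: int = 128

@dataclass
class FDConfig:
    model_config: ModelConfig

class MoELayer:
    def __init__(self, fd_config: FDConfig):
        # The layer holds the full FDConfig; presumably it does not expose a
        # top-level `model_config`, which is why the old access failed.
        self.fd_config = fd_config

layer = MoELayer(FDConfig(ModelConfig()))

# Old (broken): layer.model_config.num_max_dispatch_tokens_per_rank
# New (fixed):
tokens_per_rank = layer.fd_config.model_config.num_max_dispatch_tokens_per_rank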

@@ -241,7 +241,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
             [
                 layer.num_local_experts,
                 layer.ep_size *
-                layer.model_config.num_max_dispatch_tokens_per_rank,
+                layer.fd_config.model_config.num_max_dispatch_tokens_per_rank,
                 layer.moe_intermediate_size * 2,
             ],
             dtype=paddle.bfloat16,
@@ -251,7 +251,7 @@ class DeepGemmFusedMoeMethod(MoEMethodBase):
             [
                 layer.num_local_experts,
                 layer.ep_size *
-                layer.model_config.num_max_dispatch_tokens_per_rank,
+                layer.fd_config.model_config.num_max_dispatch_tokens_per_rank,
                 layer.hidden_size,
             ],
             dtype=paddle.bfloat16,
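
For reference, the two buffers above are decode-stage dispatch outputs in the DeepGemm fused MoE path; their second dimension must cover the worst case of every EP rank sending its full per-rank token budget. A rough sketch of the shape math with made-up values follows; the variable names and numbers are illustrative, only the shape expressions mirror the diff.

import paddle

# Illustrative values; the real ones come from the layer and
# layer.fd_config.model_config.
num_local_experts = 8
ep_size = 4
num_max_dispatch_tokens_per_rank = 128
moe_intermediate_size = 1536
hidden_size = 4096

# Per-expert capacity covers all EP ranks dispatching their full budget,
# hence ep_size * num_max_dispatch_tokens_per_rank tokens per local expert.
ffn_out = paddle.empty(
    [num_local_experts,
     ep_size * num_max_dispatch_tokens_per_rank,
     moe_intermediate_size * 2],  # *2 presumably packs the gate/up projections
    dtype=paddle.bfloat16,
)
down_out = paddle.empty(
    [num_local_experts,
     ep_size * num_max_dispatch_tokens_per_rank,
     hidden_size],
    dtype=paddle.bfloat16,
)

print(ffn_out.shape)   # [8, 512, 3072]
print(down_out.shape)  # [8, 512, 4096]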