[BugFix] BF16 MoE Cutlass Backend Support EP (#5242)

Author: chen
Date: 2025-11-26 19:16:22 +08:00
Committed by: GitHub
Parent: ba915e03e1
Commit: 209970836e
4 changed files with 22 additions and 3 deletions


@@ -304,6 +304,8 @@ class ModelConfig:
         if hasattr(self, "num_experts") and getattr(self, "moe_num_experts") is None:
             self.moe_num_experts = self.num_experts
+        if hasattr(self, "n_routed_experts") and getattr(self, "moe_num_experts") is None:
+            self.moe_num_experts = self.n_routed_experts

     def read_from_env(self):
         """

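Checkpoints spell the expert count differently (num_experts in some configs, n_routed_experts in DeepSeek/GLM-4-MoE-style ones), and the hunk above normalizes both onto moe_num_experts with a first-match-wins fallback. A self-contained sketch of the same chain; _ConfigSketch is an illustrative stand-in, not FastDeploy's ModelConfig:

    # Sketch: normalize differently named expert-count fields onto one
    # canonical attribute, mirroring the fallback chain in the hunk above.
    class _ConfigSketch:  # hypothetical stand-in for ModelConfig
        def __init__(self, **kwargs):
            self.moe_num_experts = None
            for k, v in kwargs.items():
                setattr(self, k, v)
            # First alias wins; later checks fire only if still unset.
            if hasattr(self, "num_experts") and self.moe_num_experts is None:
                self.moe_num_experts = self.num_experts
            if hasattr(self, "n_routed_experts") and self.moe_num_experts is None:
                self.moe_num_experts = self.n_routed_experts

    assert _ConfigSketch(n_routed_experts=64).moe_num_experts == 64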

@@ -206,7 +206,10 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
             tmp_ffn_out = recv_x
         # 4. EP combine
-        return self.ep_prefill_runner.combine(tmp_ffn_out, handle, recv_topk_weights)
+        tmp_ffn_out, event = self.ep_prefill_runner.combine(tmp_ffn_out, handle, recv_topk_weights)
+        if self.ep_prefill_runner.ep_engine.async_finish:
+            event.current_stream_wait()
+        return tmp_ffn_out

     def apply_ep_decode(
         self,
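With expert parallelism, the prefill-side combine now returns the output tensor together with a communication event; when the EP engine runs with async_finish enabled, that event must be waited on before the current stream may touch the result. A minimal sketch of the consume pattern; finish_combine is a hypothetical helper, while the runner/ep_engine attribute names mirror the hunk:

    # Sketch: consuming an async EP combine. When async_finish is set,
    # combine() returns (tensor, event) and the event must block the
    # current compute stream before the combined tensor is read.
    def finish_combine(runner, ffn_out, handle, topk_weights):
        out, event = runner.combine(ffn_out, handle, topk_weights)
        if runner.ep_engine.async_finish:
            # Stalls the current stream until the comm kernel finishes.
            event.current_stream_wait()
        return out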
@@ -242,7 +245,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
         if self.moe_quant_type == "w4a8" or self.moe_quant_type == "w4afp8":
             num_local_experts, max_num, _ = permute_input.shape
             expert_idx_per_token = paddle.arange(num_local_experts)[:, None].tile([1, max_num])
-        elif self.moe_quant_type in ["weight_only_int8", "weight_only_int4"]:
+        elif self.moe_quant_type in ["weight_only_int8", "weight_only_int4", "w16a16"]:
             expert_idx_per_token = None
         else:
             raise NotImplementedError

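For the w4a8/w4afp8 path, the grouped GEMM needs to know which local expert each permuted token slot belongs to, so the branch above tags every slot in row e with index e; weight-only kernels and the new unquantized w16a16 (BF16) path skip this. A small sketch of the shape involved, with toy sizes not taken from the diff:

    # Sketch: for permuted input of shape [num_local_experts, max_num, hidden],
    # every token slot in row e gets expert index e.
    import paddle

    num_local_experts, max_num = 4, 3
    expert_idx_per_token = paddle.arange(num_local_experts)[:, None].tile([1, max_num])
    print(expert_idx_per_token.shape)  # [4, 3]
    # row 0 -> [0, 0, 0], row 1 -> [1, 1, 1], ...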

@@ -808,7 +808,7 @@ class Wfp8Afp8MoEMethod(QuantMethodBase):
                 N=hidden_size,
                 K=moe_intermediate_size,
                 stride_am=x_q.strides[0],
-                stride_ak=x_scale.strides[1],
+                stride_ak=x_q.strides[1],
                 stride_be=layer.down_proj_weight.strides[0],
                 stride_bk=layer.down_proj_weight.strides[2],
                 stride_bn=layer.down_proj_weight.strides[1],

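The one-character fix above matters because stride_ak is the K-stride of the GEMM's activation operand A, which is x_q; reading it from x_scale only worked when the scale tensor's layout happened to line up. A sketch of the stride values, assuming row-major contiguous tensors and that paddle's Tensor.strides reports element strides as the hunk itself relies on; the shapes are illustrative:

    # Sketch: for a row-major [M, K] activation, strides == [K, 1],
    # while a per-row scale tensor of shape [M, 1] has strides unrelated
    # to K. Taking stride_ak from x_scale coincided with the right value
    # only by accident of layout.
    import paddle

    x_q = paddle.zeros([8, 4096], dtype="float32")   # stand-in for FP8 activations
    x_scale = paddle.zeros([8, 1], dtype="float32")  # per-row dequant scales
    print(x_q.strides[0], x_q.strides[1])            # 4096 1
    print(x_scale.strides[1])                        # 1, by coincidence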

@@ -494,6 +494,20 @@ class Glm4MoeForCausalLM(ModelForCasualLM):
         return logits

+    def empty_input_forward(self):
+        """
+        empty_input_forward
+        """
+        fake_hidden_states = paddle.ones(
+            shape=[1, self.fd_config.model_config.hidden_size],
+            dtype=paddle.get_default_dtype(),
+        )
+        for i in range(
+            self.fd_config.model_config.first_k_dense_replace,
+            self.fd_config.model_config.num_hidden_layers,
+        ):
+            self.model.layers[i].mlp.experts(fake_hidden_states, self.model.layers[i].mlp.gate)
+
     def forward(
         self,
         ids_remove_padding: paddle.Tensor,
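empty_input_forward exists because under EP all ranks must enter the expert all-to-all together: a rank assigned no tokens for a step still pushes a fake [1, hidden_size] activation through each MoE layer (layers before first_k_dense_replace are dense in GLM-4-MoE and hold no experts), so the collective does not stall. A sketch of the intended call pattern; everything outside the diff here (step, batch and its fields) is hypothetical, and the real scheduling logic lives in the model runner:

    # Sketch (hypothetical call site): a rank with no scheduled tokens
    # still joins the MoE collectives via the dummy forward.
    def step(model, batch):
        if batch is None:                # no tokens routed to this EP rank
            model.empty_input_forward()  # participate with fake input
            return None
        return model.forward(batch.ids_remove_padding, batch.forward_meta)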