Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-12-24 13:28:13 +08:00)
[BugFix] BF16 MoE Cutlass Backend Support EP (#5242)
@@ -304,6 +304,8 @@ class ModelConfig:
         if hasattr(self, "num_experts") and getattr(self, "moe_num_experts") is None:
             self.moe_num_experts = self.num_experts
+        if hasattr(self, "n_routed_experts") and getattr(self, "moe_num_experts") is None:
+            self.moe_num_experts = self.n_routed_experts
 
     def read_from_env(self):
         """
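
In isolation, the hunk above just normalizes the different expert-count field names that checkpoint configs use (num_experts, n_routed_experts) onto moe_num_experts. A minimal, self-contained sketch of that aliasing, using a hypothetical stand-in class rather than FastDeploy's real ModelConfig:

# Hypothetical stand-in; only the expert-count aliasing is shown.
class _ConfigSketch:
    def __init__(self, **kwargs):
        self.moe_num_experts = None
        for key, value in kwargs.items():
            setattr(self, key, value)
        # Prefer an explicit moe_num_experts; otherwise fall back to whichever
        # field name the checkpoint config happens to provide.
        if hasattr(self, "num_experts") and getattr(self, "moe_num_experts") is None:
            self.moe_num_experts = self.num_experts
        if hasattr(self, "n_routed_experts") and getattr(self, "moe_num_experts") is None:
            self.moe_num_experts = self.n_routed_experts

print(_ConfigSketch(n_routed_experts=128).moe_num_experts)  # 128
print(_ConfigSketch(num_experts=64).moe_num_experts)        # 64
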
@@ -206,7 +206,10 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
             tmp_ffn_out = recv_x
 
         # 4. EP combine
-        return self.ep_prefill_runner.combine(tmp_ffn_out, handle, recv_topk_weights)
+        tmp_ffn_out, event = self.ep_prefill_runner.combine(tmp_ffn_out, handle, recv_topk_weights)
+        if self.ep_prefill_runner.ep_engine.async_finish:
+            event.current_stream_wait()
+        return tmp_ffn_out
 
     def apply_ep_decode(
         self,
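
The change above switches the prefill path from returning the combine result directly to unpacking an (output, event) pair and waiting on the event when the EP engine finishes its communication asynchronously. A rough sketch of that pattern with hypothetical stand-ins (the real runner, engine, and event types come from FastDeploy's EP runner):

# Hypothetical stand-ins illustrating the async-combine pattern only.
class _Event:
    def current_stream_wait(self):
        # In the real backend this makes the current compute stream wait for
        # the communication stream that produced the combined output.
        pass

class _Engine:
    async_finish = True

class _PrefillRunner:
    ep_engine = _Engine()

    def combine(self, ffn_out, handle, topk_weights):
        # Returns the combined tensor plus an event to synchronize on.
        return ffn_out, _Event()

runner = _PrefillRunner()
tmp_ffn_out, event = runner.combine("ffn_out", handle=None, topk_weights=None)
if runner.ep_engine.async_finish:
    event.current_stream_wait()  # only required when combine ran asynchronously
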
@@ -242,7 +245,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
         if self.moe_quant_type == "w4a8" or self.moe_quant_type == "w4afp8":
             num_local_experts, max_num, _ = permute_input.shape
             expert_idx_per_token = paddle.arange(num_local_experts)[:, None].tile([1, max_num])
-        elif self.moe_quant_type in ["weight_only_int8", "weight_only_int4"]:
+        elif self.moe_quant_type in ["weight_only_int8", "weight_only_int4", "w16a16"]:
             expert_idx_per_token = None
         else:
             raise NotImplementedError
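
In the w4a8/w4afp8 branch above, expert_idx_per_token is a dense [num_local_experts, max_num] map that repeats each local expert id across the padded token slots, while the weight-only and w16a16 (plain BF16/FP16) branches pass None. A small sketch of the tensor being built, assuming Paddle is installed and using toy sizes:

import paddle

num_local_experts, max_num = 4, 3  # toy sizes
expert_idx_per_token = paddle.arange(num_local_experts)[:, None].tile([1, max_num])
print(expert_idx_per_token.shape)   # [4, 3]
print(expert_idx_per_token.numpy())
# [[0 0 0]
#  [1 1 1]
#  [2 2 2]
#  [3 3 3]]
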
@@ -808,7 +808,7 @@ class Wfp8Afp8MoEMethod(QuantMethodBase):
             N=hidden_size,
             K=moe_intermediate_size,
             stride_am=x_q.strides[0],
-            stride_ak=x_scale.strides[1],
+            stride_ak=x_q.strides[1],
             stride_be=layer.down_proj_weight.strides[0],
             stride_bk=layer.down_proj_weight.strides[2],
             stride_bn=layer.down_proj_weight.strides[1],
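
The one-line fix above passes the K-dimension stride of the quantized activation x_q itself, rather than the stride of its scale tensor x_scale, to the down-projection kernel. For a contiguous row-major activation of shape [M, K] the element strides are simply [K, 1], as in this sketch (shapes assumed for illustration):

import paddle

M, K = 8, 16                       # toy sizes; x_q stands in for the quantized activations
x_q = paddle.zeros([M, K], dtype="float32")
print(x_q.strides)                 # [16, 1] for a contiguous row-major tensor
stride_am, stride_ak = x_q.strides[0], x_q.strides[1]
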
@@ -494,6 +494,20 @@ class Glm4MoeForCausalLM(ModelForCasualLM):
 
         return logits
 
+    def empty_input_forward(self):
+        """
+        empty_input_forward
+        """
+        fake_hidden_states = paddle.ones(
+            shape=[1, self.fd_config.model_config.hidden_size],
+            dtype=paddle.get_default_dtype(),
+        )
+        for i in range(
+            self.fd_config.model_config.first_k_dense_replace,
+            self.fd_config.model_config.num_hidden_layers,
+        ):
+            self.model.layers[i].mlp.experts(fake_hidden_states, self.model.layers[i].mlp.gate)
+
     def forward(
         self,
         ids_remove_padding: paddle.Tensor,
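
empty_input_forward above pushes a dummy [1, hidden_size] activation through the experts of every MoE layer; presumably this keeps a rank that received no real tokens participating in the expert-parallel collectives. A stripped-down sketch of the same control flow with hypothetical stand-in layers:

import paddle

# Hypothetical stand-ins; the real layers live in FastDeploy's GLM-4-MoE model.
class _MoEMLP:
    gate = None  # stand-in for the routing gate

    def experts(self, hidden_states, gate):
        # The real implementation dispatches tokens to experts (and, under EP,
        # across ranks), so it has to be entered even with a dummy input.
        return hidden_states

class _Layer:
    def __init__(self):
        self.mlp = _MoEMLP()

hidden_size = 4
first_k_dense_replace, num_hidden_layers = 1, 4
layers = [_Layer() for _ in range(num_hidden_layers)]

fake_hidden_states = paddle.ones([1, hidden_size], dtype=paddle.get_default_dtype())
for i in range(first_k_dense_replace, num_hidden_layers):
    layers[i].mlp.experts(fake_hidden_states, layers[i].mlp.gate)
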