[BugFix] fix VL fp8 bug when moe token_num is 0 (#4928)

* [BugFix] fix VL fp8 bug when moe token_num is 0

* fix bug

* format

* fix bug
Author: ming1753
Date: 2025-11-12 21:19:36 +08:00
Committed by: GitHub
Parent: c8140326fa
Commit: 3148dbca06


@@ -287,8 +287,10 @@ class TritonWeightOnlyMoEMethod(QuantMethodBase):
"""
Triton compute Fused MoE.
"""
gate_out = gate(x.cast("float32"))
token_num = x.shape[0]
if token_num == 0:
return paddle.zeros([token_num, layer.hidden_size], dtype=x.dtype)
gate_out = gate(x.cast("float32"))
top_k = layer.top_k
num_local_experts = layer.num_local_experts
top_k = layer.top_k
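
All four patched backends apply the same guard: read the token count before running the gate, and return an empty output of the expected shape and dtype when there are no tokens to route (e.g. a VL batch that contributes zero tokens on this rank). Below is a minimal runnable sketch of the pattern; the function and size names are illustrative placeholders, not the project's actual API:

```python
import paddle

def fused_moe_apply(x, hidden_size=8):
    """Illustrative stand-in for the patched MoE apply() methods."""
    token_num = x.shape[0]
    # With zero tokens there is nothing to route, and the fused Triton
    # kernels would otherwise be launched on an empty batch, so return
    # an empty [0, hidden_size] tensor in the input dtype instead.
    if token_num == 0:
        return paddle.zeros([token_num, hidden_size], dtype=x.dtype)
    # ... gating, top-k routing, and fused expert GEMMs run here in the
    # real methods; a pass-through stands in for them in this sketch ...
    return x

# Zero-token input takes the early-return path:
empty = paddle.zeros([0, 8], dtype="float16")
print(fused_moe_apply(empty).shape)  # [0, 8]
```
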
@@ -669,8 +671,10 @@ class Wfp8Afp8MoEMethod(QuantMethodBase):
"""
Triton compute Fused MoE.
"""
gate_out = gate(x.cast("float32"))
token_num = x.shape[0]
if token_num == 0:
return paddle.zeros([token_num, layer.hidden_size], dtype=x.dtype)
gate_out = gate(x.cast("float32"))
top_k = layer.top_k
num_local_experts = layer.num_local_experts
moe_intermediate_size = layer.moe_intermediate_size
@@ -959,8 +963,10 @@ class TensorWiseFP8MoEMethod(QuantMethodBase):
"""
Triton compute Fused MoE.
"""
gate_out = gate(x.cast("float32"))
token_num = x.shape[0]
if token_num == 0:
return paddle.zeros([token_num, layer.hidden_size], dtype=x.dtype)
gate_out = gate(x.cast("float32"))
top_k = layer.top_k
num_local_experts = layer.num_local_experts
moe_intermediate_size = layer.moe_intermediate_size
@@ -1480,8 +1486,10 @@ class BlockWiseFP8MoEMethod(QuantMethodBase):
"""
Triton compute Fused MoE.
"""
gate_out = gate(x.cast("float32"))
token_num = x.shape[0]
if token_num == 0:
return paddle.zeros([token_num, layer.hidden_size], dtype=x.dtype)
gate_out = gate(x.cast("float32"))
top_k = layer.top_k
num_local_experts = layer.num_local_experts
moe_intermediate_size = layer.moe_intermediate_size
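
As a quick sanity check (an illustrative snippet, not part of the commit), the early-returned tensor already has the shape and dtype that downstream consumers of the MoE output expect:

```python
import paddle

hidden_size = 16
x = paddle.zeros([0, hidden_size], dtype="float16")           # empty MoE input
out = paddle.zeros([x.shape[0], hidden_size], dtype=x.dtype)  # what the guard returns

assert out.shape == [0, hidden_size]
assert out.dtype == x.dtype
```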